import os

import torch
import numpy as np

from typing import BinaryIO, List

from imagebind import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind.models.multimodal_preprocessors import SimpleTokenizer, TextPreprocessor


V2_URL = "https://huggingface.co/jondurbin/videobind-v0.2/resolve/main/videobind.pth"
V2_PATH = "./.checkpoints/videobind-v0.2.pth"
BPE_PATH = "./models/bpe_simple_vocab_16e6.txt.gz"

TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH)
LENGTH_TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH, context_length=1024)
TOKEN_CHUNK_SIZE = 74
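
# Note: ImageBind's text encoder uses a CLIP-style 77-token context (including the
# start/end special tokens), so TOKEN_CHUNK_SIZE = 74 presumably leaves headroom for
# those special tokens when a chunk is re-tokenized for the forward pass. The
# LENGTH_TOKENIZER (context_length=1024) is only used to measure and split long text,
# never as model input.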


def get_imagebind_v2(path: str = V2_PATH):
    """Download the VideoBind v0.2 checkpoint if it is missing, then load it."""
    if not os.path.isfile(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.hub.download_url_to_file(V2_URL, path, progress=True)
    # Use a local name so we do not shadow the imported `imagebind_model` module.
    model = torch.load(path)
    return model


def load_and_transform_text(text, device):
    """Tokenize a list of strings into a [len(text), context_length] tensor of token ids."""
    if text is None:
        return None
    tokens = [TOKENIZER(t).unsqueeze(0).to(device) for t in text]
    tokens = torch.cat(tokens, dim=0)
    return tokens
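
# Texts longer than the model tokenizer's context window cannot be fed to the text
# tower in one pass; the helpers below therefore split long descriptions into
# segments of at most TOKEN_CHUNK_SIZE tokens before tokenizing each segment.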


def split_text_by_token_limit(text, tokenizer, max_tokens=TOKEN_CHUNK_SIZE):
    """Split `text` into segments of at most `max_tokens` BPE tokens.

    Splitting prefers natural boundaries (newlines, then sentence and clause
    punctuation, then spaces) and only falls back to cutting on raw token
    boundaries when no delimiter produces small enough segments.
    """

    def fits_in_token_limit(text_segment):
        # Drop padding zeros and the start/end special tokens before counting.
        tokens = tokenizer(text_segment)
        tokens = tokens[tokens != 0][1:-1].tolist()
        return len(tokens) <= max_tokens

    def recursive_split(text, delimiters):
        if fits_in_token_limit(text):
            return [text]
        if not delimiters:
            return split_by_tokens(text)
        delimiter = delimiters[0]
        parts = text.split(delimiter)
        result = []
        current_segment = ""
        for part in parts:
            candidate_segment = current_segment + (delimiter if current_segment else '') + part
            if fits_in_token_limit(candidate_segment):
                current_segment = candidate_segment
            else:
                if current_segment:
                    result.append(current_segment)
                current_segment = part
        if current_segment:
            result.append(current_segment)
        final_result = []
        for segment in result:
            if fits_in_token_limit(segment):
                final_result.append(segment)
            else:
                # A single part can still be too long; retry with the next delimiter.
                final_result.extend(recursive_split(segment, delimiters[1:]))
        return final_result

    def split_by_tokens(text):
        tokens = tokenizer(text)
        tokens = tokens[tokens != 0][1:-1].tolist()
        # Ceil division keeps every chunk within max_tokens; the previous
        # int(len / max) rounded down and could produce oversized chunks.
        num_chunks = (len(tokens) + max_tokens - 1) // max_tokens or 1
        chunks = np.array_split(tokens, num_chunks)
        return [
            tokenizer.decode(segment_tokens)
            for segment_tokens in chunks
        ]

    return recursive_split(text, ['\n', '.', '!', '?', ',', ' '])
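
# Illustrative usage (a sketch; `long_description` is any long string):
#
#   segments = split_text_by_token_limit(long_description, LENGTH_TOKENIZER)
#   # each segment re-tokenizes to at most TOKEN_CHUNK_SIZE tokens, with splits
#   # preferring newline/sentence/clause boundaries over raw token cuts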


def load_and_transform_text_chunks(text, device):
    """Split long text into token-limited chunks and tokenize each chunk for the model."""
    if not text:
        return []
    return [
        load_and_transform_text([segment], device)
        for segment in split_text_by_token_limit(text, LENGTH_TOKENIZER)
    ]


class ImageBind:
    def __init__(self, device="cuda:0", v2=False):
        self.device = device
        self.v2 = v2
        if v2:
            # Download-and-load logic is shared with get_imagebind_v2().
            self.imagebind = get_imagebind_v2()
        else:
            self.imagebind = imagebind_model.imagebind_huge(pretrained=True)
        self.imagebind.eval()
        self.imagebind.to(self.device)

    def generate_text_embeddings(self, text: str):
        if not self.v2:
            return self.imagebind({
                ModalityType.TEXT: load_and_transform_text([text], self.device)
            })[ModalityType.TEXT]
        # v2: embed each token-limited chunk separately, then mean-pool the results.
        chunks = load_and_transform_text_chunks(text, self.device)
        embeddings = [
            self.imagebind({ModalityType.TEXT: chunk})[ModalityType.TEXT]
            for chunk in chunks
        ]
        return torch.mean(torch.stack(embeddings), dim=0)
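
    # Note: each chunk embedding above has shape [1, D], so the mean over chunks keeps
    # a [1, D] shape; long and short descriptions yield identically shaped outputs.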

    """ Deactivating full embeddings as they are not used in the current implementation
    def get_inputs(self, video_file: BinaryIO) -> dict:
        audio_file = video_utils.copy_audio(video_file.name)
        try:
            duration = video_utils.get_video_duration(video_file.name)
            video_data = data.load_and_transform_video_data(
                [video_file.name],
                self.device,
            )
            audio_data = data.load_and_transform_audio_data(
                [audio_file.name],
                self.device,
            )
            inputs = {
                ModalityType.VISION: video_data,
                ModalityType.AUDIO: audio_data,
            }
            return inputs
        finally:
            audio_file.close()

    @torch.no_grad()
    def embed(self, descriptions: List[str], video_files: List[BinaryIO]) -> Embeddings:
        return_value = None
        for idx in range(len(descriptions)):
            inputs = self.get_inputs(video_files[idx])
            embeddings = self.imagebind(inputs)
            text_embeddings = self.generate_text_embeddings(descriptions[idx])
            if not return_value:
                return_value = Embeddings(
                    video=embeddings[ModalityType.VISION],
                    audio=embeddings[ModalityType.AUDIO],
                    description=text_embeddings,
                )
            else:
                return_value.video = torch.cat((return_value.video, embeddings[ModalityType.VISION]))
                return_value.audio = torch.cat((return_value.audio, embeddings[ModalityType.AUDIO]))
                return_value.description = torch.cat((return_value.description, text_embeddings))
        return return_value

    @torch.no_grad()
    def embed_only_video(self, video_files: List[BinaryIO]) -> Embeddings:
        video_filepaths = [video_file.name for video_file in video_files]
        durations = [video_utils.get_video_duration(f.name) for f in video_files]
        embeddings = self.imagebind({
            ModalityType.VISION: [
                data.load_and_transform_video_data(
                    [video_filepaths[idx]],
                    self.device,
                )[0]
                for idx in range(len(video_filepaths))
            ]
        })
        return Embeddings(
            video=embeddings[ModalityType.VISION],
        )

    @torch.no_grad()
    def embed_video_and_text(self, video_files: List[BinaryIO], descriptions: List[str]) -> Embeddings:
        video_filepaths = [video_file.name for video_file in video_files]
        durations = [video_utils.get_video_duration(f.name) for f in video_files]
        embeddings = self.imagebind({
            ModalityType.VISION: [
                data.load_and_transform_video_data(
                    [video_filepaths[idx]],
                    self.device,
                )[0]
                for idx in range(len(video_filepaths))
            ],
        })
        description_embeddings = torch.stack([
            self.generate_text_embeddings(description)
            for description in descriptions
        ])
        return Embeddings(
            video=embeddings[ModalityType.VISION],
            description=description_embeddings,
        )

    @torch.no_grad()
    def embed_text(self, texts: List[str]) -> torch.Tensor:
        return_value = None
        for text in texts:
            emb = self.generate_text_embeddings(text)
            if not return_value:
                return_value = emb
            else:
                return_value = torch.cat((return_value, emb))
        return return_value
    """

    @torch.no_grad()
    def embed_text(self, texts: List[str]) -> torch.Tensor:
        """Embed each text independently and stack the per-text embeddings."""
        embeddings = []
        for text in texts:
            emb = self.generate_text_embeddings(text)
            embeddings.append(emb)

        if not embeddings:
            return None

        # Each per-text embedding has shape [1, D]; stacking yields [len(texts), 1, D].
        return torch.stack(embeddings, dim=0)
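

# Minimal usage sketch (illustrative only; assumes a CUDA device and that the
# VideoBind checkpoint and BPE vocabulary are reachable at the paths configured above).
if __name__ == "__main__":
    ib = ImageBind(device="cuda:0", v2=True)
    text_embeddings = ib.embed_text([
        "a dog catching a frisbee in the park",
        "timelapse of clouds moving over a city skyline",
    ])
    # [num_texts, 1, embedding_dim], per embed_text() above.
    print(text_embeddings.shape)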
|