from typing import List, Optional, Tuple

import torch
from torch import nn


def preprocess(
    image_a: Optional[torch.Tensor],
    image_b: Optional[torch.Tensor],
    image_encoder: nn.Module,
    clip_mean: torch.Tensor,
    clip_std: torch.Tensor,
    should_drop_cond: Optional[List[Tuple[bool, bool]]] = None,
    concat_hidden_states: Optional[torch.Tensor] = None,
    image_list: Optional[List[torch.Tensor]] = None,
):
    with torch.no_grad():
        # The explicit conditioning images (image_a, image_b) are placed
        # before any extra images passed via image_list.
        image_list = [] if image_list is None else image_list
        additional_list = []
        if image_a is not None:
            additional_list.append(image_a)
        if image_b is not None:
            additional_list.append(image_b)
        image_list = additional_list + image_list
        embeds_list = []
        for image in image_list:
            # A 2-D tensor (batch, dim) is already an embedding; skip the encoder.
            if image.ndim == 2:
                image_embeds = image
            else:
                encoder_outs = image_encoder(image, output_hidden_states=False)
                image_embeds = encoder_outs.image_embeds
            # Normalize with the precomputed CLIP statistics.
            image_embeds = (image_embeds - clip_mean) / clip_std
            # Add a sequence dimension so the embeddings can be concatenated along dim=1.
            embeds_list.append(image_embeds.unsqueeze(1))
        # Per-sample conditioning dropout: zero out the embedding of image_a
        # and/or image_b according to should_drop_cond.
        if should_drop_cond is not None:
            for b_ind in range(embeds_list[0].shape[0]):
                should_drop_a, should_drop_b = should_drop_cond[b_ind]
                if should_drop_a:
                    embeds_list[0][b_ind] = torch.zeros_like(embeds_list[0][b_ind])
                if should_drop_b and image_b is not None:
                    embeds_list[1][b_ind] = torch.zeros_like(embeds_list[1][b_ind])
        if concat_hidden_states is not None:
            embeds_list.append(concat_hidden_states)
        out_hidden_states = torch.cat(embeds_list, dim=1)
        # Return zeroed pooled embeds as a placeholder alongside the
        # concatenated per-image hidden states.
        image_embeds = torch.zeros_like(embeds_list[0].squeeze(1))
        return image_embeds, out_hidden_states
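

# Usage sketch, with assumptions flagged: `CLIPVisionModelWithProjection` from
# Hugging Face transformers returns an output with an `image_embeds` field,
# which matches what `preprocess` expects from `image_encoder`. The checkpoint
# name and the zero-mean/unit-std statistics below are illustrative
# placeholders, not values from this repo.
if __name__ == "__main__":
    from transformers import CLIPVisionModelWithProjection

    encoder = CLIPVisionModelWithProjection.from_pretrained(
        "openai/clip-vit-large-patch14"  # hypothetical checkpoint choice
    )
    dim = encoder.config.projection_dim
    clip_mean = torch.zeros(dim)  # placeholder normalization statistics
    clip_std = torch.ones(dim)

    image_a = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
    image_b = torch.randn(2, 3, 224, 224)

    # Drop image_b's conditioning for the second sample only.
    embeds, hidden_states = preprocess(
        image_a, image_b, encoder, clip_mean, clip_std,
        should_drop_cond=[(False, False), (False, True)],
    )
    print(embeds.shape)         # (2, dim), all zeros
    print(hidden_states.shape)  # (2, 2, dim): one token per conditioning image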