LD-T3D / feature_extractors /uni3d_embedding_encoder.py
yuanze1024's picture
remove unused code
4c05bb3
raw
history blame
2.4 kB
"""
This is a modified version that only extracts text embeddings in the HF Space.
See https://github.com/baaivision/Uni3D for source code.
Or refer to https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py for extracting all embeddings.
"""
import os
import sys
import open_clip
import torch
from huggingface_hub import hf_hub_download
sys.path.append('')
from feature_extractors import FeatureExtractor
from utils.tokenizer import SimpleTokenizer
class Uni3dEmbeddingEncoder(FeatureExtractor):
    """Encode text (and images) with the EVA02-E-14-plus CLIP model.

    The checkpoint is fetched from the Hugging Face Hub on first use and the
    model is placed on CUDA when available, otherwise on the CPU. 3D encoding
    is intentionally not available in this trimmed-down version.
    """

    def __init__(self, cache_dir, **kwargs) -> None:
        """Load the tokenizer and the CLIP model, downloading weights if absent.

        Args:
            cache_dir: Directory used as the Hub cache; the checkpoint is
                expected at ``<cache_dir>/Uni3D/open_clip_pytorch_model.bin``.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        # NOTE: relative path, so it is resolved against the current working
        # directory of the process.
        bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"

        # Build both paths from one joined directory so the existence check
        # and the download target cannot drift apart (and so path-like
        # ``cache_dir`` values work, unlike ``cache_dir + os.sep + ...``).
        model_dir = os.path.join(cache_dir, "Uni3D")
        clip_path = os.path.join(model_dir, "open_clip_pytorch_model.bin")
        if not os.path.exists(clip_path):
            hf_hub_download(
                "timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k",
                "open_clip_pytorch_model.bin",
                cache_dir=cache_dir,
                local_dir=model_dir,
            )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = SimpleTokenizer(bpe_path)
        self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name="EVA02-E-14-plus", pretrained=clip_path
        )
        self.clip_model.to(self.device)
        # open_clip documents that freshly created models are in train mode;
        # this class is inference-only, so switch to eval mode explicitly.
        self.clip_model.eval()

    @torch.no_grad()
    def encode_3D(self, data):
        """Not supported in this version; always raises.

        Raises:
            NotImplementedError: Always, pointing to the full implementation.
        """
        raise NotImplementedError("For extracting 3D feature, see https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py")

    @torch.no_grad()
    def encode_text(self, input_text):
        """Return L2-normalized float32 text embeddings for ``input_text``.

        Args:
            input_text: Text passed straight to the tokenizer.
                # presumably a string or a list of strings; confirm against
                # utils.tokenizer.SimpleTokenizer.

        Returns:
            A float32 tensor of embeddings, normalized along the last dim.
        """
        tokens = self.tokenizer(input_text).to(device=self.device, non_blocking=True)
        # A single un-batched token sequence gets a leading batch dimension.
        if tokens.dim() < 2:
            tokens = tokens.unsqueeze(0)
        embeddings = self.clip_model.encode_text(tokens)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        return embeddings.float()

    @torch.no_grad()
    def encode_image(self, img_tensor_list):
        """Return L2-normalized float32 image embeddings.

        Args:
            img_tensor_list: Tensor of images moved to the model device.
                # NOTE(review): despite the name this must support ``.to()``,
                # i.e. be a tensor rather than a Python list; confirm callers.

        Returns:
            A float32 tensor of embeddings, normalized along the last dim.
        """
        images = img_tensor_list.to(device=self.device, non_blocking=True)
        features = self.clip_model.encode_image(images)
        features = features / features.norm(dim=-1, keepdim=True)
        return features.float()

    def encode_query(self, query_list):
        """Encode text queries; thin alias for :meth:`encode_text`."""
        return self.encode_text(query_list)

    def get_img_transform(self):
        """Return the image preprocessing transform created with the model."""
        return self.preprocess