Spaces:
Runtime error
Runtime error
""" | |
This is a modified version that only extracts text (and CLIP image) embeddings for the HF Space; 3D feature extraction is not included.
See https://github.com/baaivision/Uni3D for the original source code,
or refer to https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py for extracting all embeddings.
""" | |
import os | |
import sys | |
import open_clip | |
import torch | |
from huggingface_hub import hf_hub_download | |
sys.path.append('') | |
from feature_extractors import FeatureExtractor | |
from utils.tokenizer import SimpleTokenizer | |
class Uni3dEmbeddingEncoder(FeatureExtractor):
    """Embedding encoder built on the EVA02-E-14-plus OpenCLIP model used by Uni3D.

    Only the CLIP text and image towers are usable in this trimmed-down version;
    ``encode_3D`` raises ``NotImplementedError``.
    """

    def __init__(self, cache_dir, **kwargs) -> None:
        """Load the BPE tokenizer and the OpenCLIP model, downloading weights if needed.

        Args:
            cache_dir: Directory passed to the Hugging Face Hub as its cache. The
                CLIP checkpoint is stored at
                ``<cache_dir>/Uni3D/open_clip_pytorch_model.bin``.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        # NOTE(review): relative path, resolved against the current working directory.
        bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
        clip_dir = os.path.join(cache_dir, "Uni3D")
        clip_path = os.path.join(clip_dir, "open_clip_pytorch_model.bin")
        if not os.path.exists(clip_path):
            # Download directly into clip_dir so the checkpoint ends up at clip_path.
            hf_hub_download(
                "timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k",
                "open_clip_pytorch_model.bin",
                cache_dir=cache_dir,
                local_dir=clip_dir,
            )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = SimpleTokenizer(bpe_path)
        self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name="EVA02-E-14-plus", pretrained=clip_path
        )
        self.clip_model.to(self.device)
        # OpenCLIP returns models in train mode; switch to eval so stochastic
        # layers (e.g. drop-path) do not perturb the extracted embeddings.
        self.clip_model.eval()

    def encode_3D(self, data):
        """Not supported in this version; always raises ``NotImplementedError``."""
        raise NotImplementedError("For extracting 3D feature, see https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py")

    # Inference only: skip autograd bookkeeping through the (very large) model.
    @torch.no_grad()
    def encode_text(self, input_text):
        """Embed text with the CLIP text tower.

        Args:
            input_text: Input accepted by ``SimpleTokenizer`` (presumably a string
                or a list of strings — confirm against the tokenizer).

        Returns:
            float32 embeddings, L2-normalised along the last dimension.
        """
        texts = self.tokenizer(input_text).to(device=self.device, non_blocking=True)
        # A single input may tokenize to a 1-D tensor; add a leading batch dimension.
        if texts.dim() < 2:
            texts = texts.unsqueeze(0)
        class_embeddings = self.clip_model.encode_text(texts)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        return class_embeddings.float()

    # Inference only: skip autograd bookkeeping through the (very large) model.
    @torch.no_grad()
    def encode_image(self, img_tensor_list):
        """Embed images with the CLIP vision tower.

        Args:
            img_tensor_list: Image tensor batch, presumably already transformed with
                ``get_img_transform()`` and stacked — TODO confirm with callers.

        Returns:
            float32 embeddings, L2-normalised along the last dimension.
        """
        image = img_tensor_list.to(device=self.device, non_blocking=True)
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.float()

    def encode_query(self, query_list):
        """Embed retrieval queries; queries are plain text, so this defers to ``encode_text``."""
        return self.encode_text(query_list)

    def get_img_transform(self):
        """Return the OpenCLIP preprocessing transform created alongside the model."""
        return self.preprocess