Spaces:
Build error
Build error
import pickle
from collections import defaultdict

import fire
import numpy as np
import pandas as pd
from tqdm import tqdm
class EmbeddingExtractor(object):
    """Commands (exposed on the command line via ``fire``) that turn caption
    text into sentence embeddings and save them to disk."""

    def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False):
        """Embed captions with a SentenceTransformer model and pickle the result.

        Args:
            caption_file: JSON caption table; see ``extract`` for the columns
                it must contain.
            output: Path of the pickle file to write.
            dev: Selects the output layout; see ``extract``.
            zh: Use the multilingual model ("distiluse-base-multilingual-cased")
                instead of the English one ("bert-base-nli-mean-tokens").
        """
        from sentence_transformers import SentenceTransformer
        lang2model = {
            "zh": "distiluse-base-multilingual-cased",
            "en": "bert-base-nli-mean-tokens"
        }
        lang = "zh" if zh else "en"
        model = SentenceTransformer(lang2model[lang])
        self.extract(caption_file, model, output, dev)

    def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"):
        """Embed captions through a ``bert_serving`` server and pickle the result.

        Args:
            caption_file: JSON caption table; see ``extract``.
            output: Path of the pickle file to write.
            dev: Selects the output layout; see ``extract``.
            ip: Host of the BERT server the client connects to.
        """
        from bert_serving.client import BertClient
        client = BertClient(ip)
        self.extract(caption_file, client, output, dev)

    def extract(self, caption_file: str, model, output, dev: bool):
        """Encode every caption in ``caption_file`` with ``model`` and pickle them.

        The caption file is loaded with ``pd.read_json`` (``key`` kept as str)
        and must provide ``key`` and ``caption`` columns, plus
        ``caption_index`` when ``dev`` is true.

        Args:
            caption_file: Path to the JSON caption table.
            model: Object with an ``encode(list_of_str)`` method whose result
                can be converted with ``np.array``.
            output: Path of the pickle file to write.
            dev: If true, store one flat vector per caption under the key
                ``"{key}_{caption_index}"``. Otherwise group captions by
                ``key`` and store a 2-D array with one row per caption, rows
                in file order.
        """
        caption_df = pd.read_json(caption_file, dtype={"key": str})

        def encode(caption):
            # Encode a single caption and flatten the result to a 1-D vector.
            return np.array(model.encode([caption])).reshape(-1)

        if dev:
            embeddings = {}
            with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
                for _, row in caption_df.iterrows():
                    caption_key = f"{row['key']}_{row['caption_index']}"
                    embeddings[caption_key] = encode(row["caption"])
                    pbar.update()
        else:
            grouped = defaultdict(list)
            with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
                for _, row in caption_df.iterrows():
                    grouped[row["key"]].append(encode(row["caption"]))
                    pbar.update()
            # Build a plain dict (not a defaultdict) so the pickle has the same
            # type as before; np.stack requires equal-length vectors per key.
            embeddings = {key: np.stack(vectors) for key, vectors in grouped.items()}

        with open(output, "wb") as f:
            pickle.dump(embeddings, f)

    def extract_sbert(self,
                      input_json: str,
                      output: str):
        """Embed every caption of an audio-caption JSON file into an HDF5 file.

        ``input_json`` must contain an ``"audios"`` list whose entries have an
        ``"audio_id"`` and a ``"captions"`` list of objects with ``"cap_id"``
        and ``"caption"``. Each caption is encoded with
        ``paraphrase-MiniLM-L6-v2`` (on CUDA when available) and stored in
        ``output`` under the name ``"{audio_id}_{cap_id}"``.
        """
        from sentence_transformers import SentenceTransformer
        import json
        import torch
        from h5py import File
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
        model = model.to(device)
        model.eval()
        # Close the file right after parsing (the original leaked the handle),
        # and decode as UTF-8 explicitly so non-ASCII captions do not depend on
        # the platform's default encoding.
        with open(input_json, encoding="utf-8") as reader:
            data = json.load(reader)["audios"]
        with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
            for sample in data:
                audio_id = sample["audio_id"]
                for cap in sample["captions"]:
                    cap_id = cap["cap_id"]
                    store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"])
                # One progress tick per audio sample, matching total=len(data).
                pbar.update()
if __name__ == "__main__":
    # Hand the class to python-fire so its methods can be invoked as
    # command-line sub-commands.
    fire.Fire(component=EmbeddingExtractor)