AudioGPT / audio_to_text /captioning /utils /bert /create_sent_embedding.py
lmzjms's picture
Upload 46 files
8121fee
raw
history blame
3.15 kB
import pickle
import fire
import numpy as np
import pandas as pd
from tqdm import tqdm
class EmbeddingExtractor:
    """Commands for pre-computing sentence embeddings of audio captions.

    The class is handed to ``fire.Fire`` at the bottom of this file, so the
    methods below are the command-line entry points of the script.
    """

    def extract_sentbert(self, caption_file: str, output: str, dev: bool = True, zh: bool = False) -> None:
        """Embed captions with a pretrained SentenceTransformer model.

        Args:
            caption_file: Path to the caption JSON file (see ``extract``).
            output: Path of the pickle file to write.
            dev: Forwarded to ``extract``; selects the output layout.
            zh: If true, use the multilingual model
                ("distiluse-base-multilingual-cased"); otherwise use
                "bert-base-nli-mean-tokens".
        """
        # Imported lazily so the other commands work without this package.
        from sentence_transformers import SentenceTransformer

        lang2model = {
            "zh": "distiluse-base-multilingual-cased",
            "en": "bert-base-nli-mean-tokens",
        }
        lang = "zh" if zh else "en"
        model = SentenceTransformer(lang2model[lang])
        self.extract(caption_file, model, output, dev)

    def extract_originbert(self, caption_file: str, output: str, dev: bool = True, ip="localhost") -> None:
        """Embed captions through a ``bert_serving`` client.

        Args:
            caption_file: Path to the caption JSON file (see ``extract``).
            output: Path of the pickle file to write.
            dev: Forwarded to ``extract``; selects the output layout.
            ip: Address passed to ``BertClient``.
        """
        # Imported lazily so the other commands work without this package.
        from bert_serving.client import BertClient

        client = BertClient(ip)
        self.extract(caption_file, client, output, dev)

    def extract(self, caption_file: str, model, output, dev: bool) -> None:
        """Encode every caption in ``caption_file`` and pickle the result.

        The file is read with ``pandas.read_json`` (the "key" column is
        forced to ``str``). Each row must provide "caption" and "key", plus
        "caption_index" when ``dev`` is true.

        Args:
            caption_file: Path to the caption JSON file.
            model: Any object exposing ``encode(list_of_sentences)``.
            output: Path of the pickle file to write.
            dev: If true, the pickle holds one flattened vector per caption
                under the name ``"{key}_{caption_index}"``. If false, the
                vectors of all captions sharing a key are stacked (in row
                order) into a single array stored under ``key``.
        """
        caption_df = pd.read_json(caption_file, dtype={"key": str})
        embeddings = {}

        with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
            for _, row in caption_df.iterrows():
                caption = row["caption"]
                key = row["key"]
                if dev:
                    cap_idx = row["caption_index"]
                # encode() is given a one-element batch; flatten its output
                # into a single 1-D vector for this caption.
                vector = np.array(model.encode([caption])).reshape(-1)
                if dev:
                    embeddings[f"{key}_{cap_idx}"] = vector
                else:
                    embeddings.setdefault(key, []).append(vector)
                pbar.update()

        if not dev:
            # Collapse each per-key list of vectors into one stacked array.
            embeddings = {key: np.stack(vectors) for key, vectors in embeddings.items()}

        with open(output, "wb") as f:
            pickle.dump(embeddings, f)

    def extract_sbert(self,
                      input_json: str,
                      output: str) -> None:
        """Embed captions with "paraphrase-MiniLM-L6-v2" into an HDF5 file.

        ``input_json`` must contain a top-level "audios" list whose items
        have an "audio_id" and a "captions" list; every caption item has a
        "cap_id" and a "caption" string. One dataset named
        ``"{audio_id}_{cap_id}"`` is written per caption.

        Args:
            input_json: Path to the caption JSON file.
            output: Path of the HDF5 file to create (opened in "w" mode).
        """
        # Imported lazily so the other commands work without these packages.
        from sentence_transformers import SentenceTransformer
        import json
        import torch
        from h5py import File

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
        model = model.to(device)
        model.eval()

        # Use a context manager so the JSON file handle is closed promptly
        # instead of being left to the garbage collector.
        with open(input_json) as json_file:
            data = json.load(json_file)["audios"]

        with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
            for sample in data:
                audio_id = sample["audio_id"]
                for cap in sample["captions"]:
                    cap_id = cap["cap_id"]
                    store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"])
                # The bar's total is len(data), so advance once per audio
                # sample rather than once per caption.
                pbar.update()
def _main():
    """Run the command-line interface built from ``EmbeddingExtractor``."""
    fire.Fire(EmbeddingExtractor)


if __name__ == "__main__":
    _main()