File size: 3,152 Bytes
8121fee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import pickle
import fire
import numpy as np
import pandas as pd
from tqdm import tqdm


class EmbeddingExtractor(object):
    """Encode caption text into sentence-embedding vectors and save them to disk.

    Each public method is a separate extraction front-end. The backends differ:
    sentence-transformers, a bert-as-service client, or an SBERT model writing
    to HDF5.
    """

    def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False):
        """Encode captions with a pretrained sentence-transformers model.

        Args:
            caption_file: caption JSON file, forwarded to ``extract``.
            output: path of the pickle file to write.
            dev: output layout flag, forwarded to ``extract``.
            zh: if true, use the multilingual model
                ``distiluse-base-multilingual-cased``; otherwise use
                ``bert-base-nli-mean-tokens``.
        """
        from sentence_transformers import SentenceTransformer
        lang2model = {
            "zh": "distiluse-base-multilingual-cased",
            "en": "bert-base-nli-mean-tokens"
        }
        lang = "zh" if zh else "en"
        model = SentenceTransformer(lang2model[lang])

        self.extract(caption_file, model, output, dev)

    def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"):
        """Encode captions through a ``bert_serving`` client connected to ``ip``.

        Args:
            caption_file: caption JSON file, forwarded to ``extract``.
            output: path of the pickle file to write.
            dev: output layout flag, forwarded to ``extract``.
            ip: address passed to ``BertClient``.
        """
        from bert_serving.client import BertClient
        client = BertClient(ip)

        self.extract(caption_file, client, output, dev)

    def extract(self, caption_file: str, model, output: str, dev: bool):
        """Encode every caption in ``caption_file`` and pickle the result to ``output``.

        Args:
            caption_file: JSON file read with ``pandas.read_json``. It needs a
                ``key`` column (read as ``str``) and a ``caption`` column, plus
                a ``caption_index`` column when ``dev`` is true.
            model: any object with an ``encode(list_of_str)`` method. The
                output for one caption is flattened to a 1-D numpy array.
            output: path of the pickle file to write.
            dev: controls the output layout.
                - If true, the pickled dict maps ``"{key}_{caption_index}"``
                  to a 1-D vector per caption.
                - If false, it maps ``key`` to a 2-D array. The array stacks
                  that key's caption vectors in file order.
        """
        caption_df = pd.read_json(caption_file, dtype={"key": str})

        def encode(caption):
            # One caption per call; flatten whatever array-like the backend returns.
            return np.array(model.encode([caption])).reshape(-1)

        rows = tqdm(caption_df.iterrows(), total=caption_df.shape[0], ascii=True)

        if dev:
            embeddings = {}
            for _, row in rows:
                caption = row["caption"]
                name = f"{row['key']}_{row['caption_index']}"
                embeddings[name] = encode(caption)
        else:
            # Plain dict + setdefault keeps keys in first-seen order.
            grouped = {}
            for _, row in rows:
                grouped.setdefault(row["key"], []).append(encode(row["caption"]))
            embeddings = {key: np.stack(vectors) for key, vectors in grouped.items()}

        # NOTE: the output is a pickle; only load it from trusted locations.
        with open(output, "wb") as f:
            pickle.dump(embeddings, f)

    def extract_sbert(self,
                      input_json: str,
                      output: str):
        """Encode captions with ``paraphrase-MiniLM-L6-v2`` and store them in HDF5.

        Args:
            input_json: JSON file whose ``"audios"`` list holds entries with
                ``"audio_id"`` and a ``"captions"`` list. Each caption entry
                has ``"cap_id"`` and ``"caption"``.
            output: HDF5 file to create (existing file is overwritten, mode
                ``"w"``). It gets one dataset per caption, named
                ``"{audio_id}_{cap_id}"``.
        """
        from sentence_transformers import SentenceTransformer
        import json
        import torch
        from h5py import File
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
        model = model.to(device)
        model.eval()

        # Close the input file promptly.
        # JSON is UTF-8 by specification, so don't rely on the locale default.
        with open(input_json, encoding="utf-8") as reader:
            data = json.load(reader)["audios"]

        with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
            for sample in data:
                audio_id = sample["audio_id"]
                for cap in sample["captions"]:
                    cap_id = cap["cap_id"]
                    store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"])
                pbar.update()


if __name__ == "__main__":
    # Let python-fire build the command line from EmbeddingExtractor's methods
    # (e.g. `extract_sentbert`, `extract_sbert`).
    fire.Fire(component=EmbeddingExtractor)