Add the example usage.

Files added:

- usage/README.md (+43)
- usage/conf/generate_prototypes.yaml (+9)
- usage/conf/run_linking.yaml (+26)
- usage/generate_prototypes.py (+61)
- usage/requirements.txt (+4)
- usage/run_entity_linking.py (+440)
- usage/utils.py (+289)
usage/README.md (ADDED)

````markdown
# Knowledge-Rich Self-Supervision (KRISS) for Biomedical Entity Linking

Usage code for the entity linking approach described in the following paper:

```bibtex
@article{kriss,
  author = {Sheng Zhang and Hao Cheng and Shikhar Vashishth and Cliff Wong and Jinfeng Xiao and Xiaodong Liu and Tristan Naumann and Jianfeng Gao and Hoifung Poon},
  title = {Knowledge-Rich Self-Supervision for Biomedical Entity Linking},
  year = {2021},
  url = {https://arxiv.org/abs/2112.07887},
  eprinttype = {arXiv},
  eprint = {2112.07887},
}
```

[https://arxiv.org/pdf/2112.07887.pdf](https://arxiv.org/pdf/2112.07887.pdf)

## Usage of KRISS for Entity Linking

Here, we use the [MedMentions](https://github.com/chanzuckerberg/MedMentions) data to show you how to 1) generate prototype embeddings, and 2) run entity linking.

(We are currently unable to release the self-supervised mention examples, because they require UMLS and PubMed licenses.)

### 1. Create conda environment and install requirements

```bash
conda create -n kriss -y python=3.8 && conda activate kriss
pip install -r requirements.txt
```

### 2. Download the MedMentions dataset

```bash
git clone https://github.com/chanzuckerberg/MedMentions.git
```

### 3. Generate prototype embeddings

```bash
python generate_prototypes.py
```

### 4. Run entity linking

```bash
python run_entity_linking.py
```
````
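Steps 3 and 4 drive the encoder through Hydra configs, but the checkpoint can also be exercised directly. Below is a minimal sketch (not part of the commit) of embedding a single mention; it mirrors `utils.ContextualMention.to_tensor` and, like that code, assumes the `<ENT>`/`</ENT>` marker tokens are present in the released tokenizer's vocabulary.

```python
# Minimal sketch (not part of the commit): embed one mention directly,
# mirroring utils.ContextualMention.to_tensor.
import torch
from transformers import AutoTokenizer, AutoModel

model_id = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
encoder = AutoModel.from_pretrained(model_id).eval()

# Wrap the mention in <ENT> ... </ENT> between its left and right context,
# assuming these marker tokens exist in the released vocabulary.
ctx_l = tokenizer.encode("The patient was treated with", add_special_tokens=False)
ment = tokenizer.encode("aspirin", add_special_tokens=False)
ctx_r = tokenizer.encode("for chest pain.", add_special_tokens=False)
ent_l = tokenizer.convert_tokens_to_ids(["<ENT>"])
ent_r = tokenizer.convert_tokens_to_ids(["</ENT>"])

ids = [tokenizer.cls_token_id] + ctx_l + ent_l + ment + ent_r + ctx_r + [tokenizer.sep_token_id]
input_ids = torch.tensor([ids])
with torch.inference_mode():
    out = encoder(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
embedding = out[0][:, 0, :]  # the [CLS] vector, as in utils.generate_vectors
print(embedding.shape)
```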
usage/conf/generate_prototypes.yaml (ADDED)

```yaml
model_name_or_path: microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL
train_data:
  _target_: utils.MedMentionsDataset
  dataset_path: MedMentions/full/data/
  split: train
batch_size: 256
max_length: 64
output_prototypes: prototypes/embeddings
output_name_cuis: prototypes/name_cuis
```
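A note for readers new to Hydra: the `_target_` key makes `hydra.utils.instantiate(cfg.train_data)` construct a `utils.MedMentionsDataset`, passing the sibling keys as constructor arguments. Hand-written, the instantiation amounts to this sketch (not part of the commit):

```python
# Hand-written equivalent of hydra.utils.instantiate(cfg.train_data)
# for the config above (a sketch, not part of the commit).
from utils import MedMentionsDataset

ds = MedMentionsDataset(
    dataset_path="MedMentions/full/data/",
    split="train",
)
print(len(ds))  # number of contextual mentions in the training split
```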
usage/conf/run_linking.yaml (ADDED)

```yaml
# path to pretrained model and tokenizer
model_name_or_path: microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL

test_data:
  _target_: utils.MedMentionsDataset
  dataset_path: MedMentions/full/data/
  split: test

# paths to encoded data
encoded_files: [
  prototypes/embeddings
]

encoded_umls_files: []

entity_list_ids:

entity_list_names: prototypes/name_cuis

index_path:

seed: 12345
batch_size: 256
max_length: 64
num_retrievals: 100
top_ks: [1, 5, 50, 100]
```
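`entity_list_names` points at the `name_cuis` file that `generate_prototypes.py` writes: one line per distinct mention string, in the form `CUI1|CUI2||name`. The sketch below parses it the same way `evaluate()` in `run_entity_linking.py` does:

```python
# Sketch: parse the name_cuis lookup table ("CUI1|CUI2||name" per line),
# exactly as evaluate() in run_entity_linking.py does.
lut = {}
with open("prototypes/name_cuis", encoding="utf-8") as f:
    for ln in f:
        cuis, name = ln.strip().split("||")
        lut[name] = cuis.split("|")
print(len(lut), "distinct mention strings")
```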
usage/generate_prototypes.py (ADDED)

```python
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Command-line tool that produces embeddings for a large set of entity mentions
based on the pretrained mention encoder.
"""
import logging
import os
import pathlib
import pickle

import hydra
from omegaconf import DictConfig, OmegaConf
from transformers import AutoConfig, AutoTokenizer, AutoModel

from utils import generate_vectors


# Setup logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_formatter = logging.Formatter(
    "[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
console = logging.StreamHandler()
console.setFormatter(log_formatter)
logger.addHandler(console)


@hydra.main(config_path="conf", config_name="generate_prototypes", version_base=None)
def main(cfg: DictConfig):
    logger.info("Configuration:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    config = AutoConfig.from_pretrained(cfg.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name_or_path,
        use_fast=True,
    )
    encoder = AutoModel.from_pretrained(
        cfg.model_name_or_path,
        config=config
    )
    encoder.cuda()
    encoder.eval()

    ds = hydra.utils.instantiate(cfg.train_data)
    data = generate_vectors(encoder, tokenizer, ds, cfg.batch_size, cfg.max_length, is_prototype=True)
    pathlib.Path(os.path.dirname(cfg.output_prototypes)).mkdir(parents=True, exist_ok=True)
    logger.info("Writing results to %s", cfg.output_prototypes)
    with open(cfg.output_prototypes, mode="wb") as f:
        pickle.dump(data, f)
    with open(cfg.output_name_cuis, 'w') as f:
        for name, cuis in ds.name_to_cuis.items():
            f.write('|'.join(cuis) + '||' + name + '\n')
    logger.info("Total data processed %d. Written to %s", len(data), cfg.output_prototypes)


if __name__ == "__main__":
    main()
```
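The pickle written to `output_prototypes` is a list of `(meta, vector)` pairs: `meta` is a `{'cuis': [...]}` dict and `vector` is the flattened `[CLS]` embedding as a numpy array (see `generate_vectors` in `utils.py`). A quick inspection sketch, assuming the default output path and a completed run:

```python
# Sketch: inspect the prototype pickle written by generate_prototypes.py.
import pickle

with open("prototypes/embeddings", "rb") as f:
    data = pickle.load(f)

meta, vec = data[0]
print(meta)       # e.g. {'cuis': ['C0032344']}
print(vec.shape)  # (hidden_size,), i.e. 768 for this encoder
```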
usage/requirements.txt (ADDED)

```text
transformers==4.17.0
torch==1.11
hydra-core==1.2.0
faiss-gpu==1.7.0
```
usage/run_entity_linking.py (ADDED)

```python
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Run entity linking
"""

import os
import glob
import logging
import pathlib
import pickle
import time
import math
import multiprocessing
from typing import List, Tuple, Dict, Iterator, Set
from functools import partial
from multiprocessing.dummy import Pool

import hydra
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from torch import Tensor as T
from torch import nn
import faiss

from transformers import (
    set_seed,
    AutoConfig,
    AutoTokenizer,
    AutoModel,
    PreTrainedTokenizer,
)
from utils import generate_vectors


# Setup logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_formatter = logging.Formatter(
    "[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
console = logging.StreamHandler()
console.setFormatter(log_formatter)
logger.addHandler(console)


class DenseIndexer(object):
    def __init__(self, buffer_size: int = 50000):
        self.buffer_size = buffer_size
        self.index_id_to_db_id = []
        self.index = None

    def init_index(self, vector_sz: int):
        raise NotImplementedError

    def index_data(self, data: List[Tuple[object, np.array]]):
        raise NotImplementedError

    def get_index_name(self):
        raise NotImplementedError

    def search_knn(
        self, query_vectors: np.array, top_docs: int
    ) -> List[Tuple[List[object], List[float]]]:
        raise NotImplementedError

    def serialize(self, file: str):
        logger.info("Serializing index to %s", file)

        if os.path.isdir(file):
            index_file = os.path.join(file, "index.dpr")
            meta_file = os.path.join(file, "index_meta.dpr")
        else:
            index_file = file + ".index.dpr"
            meta_file = file + ".index_meta.dpr"

        faiss.write_index(self.index, index_file)
        with open(meta_file, mode="wb") as f:
            pickle.dump(self.index_id_to_db_id, f)

    def get_files(self, path: str):
        if os.path.isdir(path):
            index_file = os.path.join(path, "index.dpr")
            meta_file = os.path.join(path, "index_meta.dpr")
        else:
            index_file = path + ".index.dpr"
            meta_file = path + ".index_meta.dpr"
        return index_file, meta_file

    def index_exists(self, path: str):
        index_file, meta_file = self.get_files(path)
        return os.path.isfile(index_file) and os.path.isfile(meta_file)

    def deserialize(self, path: str):
        logger.info("Loading index from %s", path)
        index_file, meta_file = self.get_files(path)

        self.index = faiss.read_index(index_file)
        logger.info(
            "Loaded index of type %s and size %d", type(self.index), self.index.ntotal
        )

        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert (
            len(self.index_id_to_db_id) == self.index.ntotal
        ), "Deserialized index_id_to_db_id should match faiss index size"

    def _update_id_mapping(self, db_ids: List) -> int:
        self.index_id_to_db_id.extend(db_ids)
        return len(self.index_id_to_db_id)


class DenseFlatIndexer(DenseIndexer):
    def __init__(self, buffer_size: int = 50000):
        super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)

    def init_index(self, vector_sz: int):
        self.index = faiss.IndexFlatIP(vector_sz)

    def index_data(self, data: List[Tuple[object, np.array]]):
        n = len(data)
        # indexing in batches is beneficial for many faiss index types
        for i in range(0, n, self.buffer_size):
            db_ids = [t[0] for t in data[i : i + self.buffer_size]]
            vectors = [
                np.reshape(t[1], (1, -1)) for t in data[i : i + self.buffer_size]
            ]
            vectors = np.concatenate(vectors, axis=0)
            total_data = self._update_id_mapping(db_ids)
            self.index.add(vectors)
            logger.info("data indexed %d", total_data)

        indexed_cnt = len(self.index_id_to_db_id)
        logger.info("Total data indexed %d", indexed_cnt)

    def search_knn(
        self, query_vectors: np.array, top_docs: int, batch_size: int = 4096,
    ) -> List[Tuple[List[object], List[float]]]:
        num_queries = query_vectors.shape[0]
        scores, indexes = [], []
        for start in range(0, num_queries, batch_size):
            logger.info(f"Searched {start} queries.")
            batch_vectors = query_vectors[start:start + batch_size]
            batch_scores, batch_indexes = self.index.search(batch_vectors, top_docs)
            scores.extend(batch_scores)
            indexes.extend(batch_indexes)
        # convert to external ids
        db_ids = [
            [self.index_id_to_db_id[i] for i in query_top_idxs]
            for query_top_idxs in indexes
        ]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result

    def get_index_name(self):
        return "flat_index"


def load_umls_data(files_patterns: List[str], candidate_ids: Set = None) -> Dict:
    input_paths = []
    for pattern in files_patterns:
        pattern_files = glob.glob(pattern)
        input_paths.extend(pattern_files)
    umls_data = {}
    for file in sorted(input_paths):
        logger.info("Reading encoded UMLS data from file %s", file)
        with open(file, "rb") as reader:
            for meta, vec in pickle.load(reader):
                assert len(meta['cuis']) == 1, breakpoint()
                cui = meta['cuis'][0]
                if candidate_ids and cui not in candidate_ids:
                    continue
                umls_data[cui] = (meta, vec)
    logger.info(f"Loaded UMLS data = {len(umls_data)}.")
    return umls_data


def iterate_encoded_files(
    vector_files: list,
    candidate_ids: Set = None,
    umls_data: Dict = None,
) -> Iterator:
    logger.info("Loading encoded prototype embeddings...")
    proto_data = {}
    for file in vector_files:
        logger.info("Reading file %s", file)
        with open(file, "rb") as reader:
            for meta, vec in pickle.load(reader):
                cuis = meta['cuis']
                if candidate_ids and all(c not in candidate_ids for c in cuis):
                    continue
                for cui in cuis:
                    proto_data.setdefault(cui, []).append((meta, vec))
    # Concatenate prototype embs with additional knowledge embs from UMLS.
    if umls_data is not None:
        for cui, (meta, vec) in umls_data.items():
            if cui in proto_data:
                for _, _vec in proto_data.pop(cui):
                    extended_vec = np.concatenate((vec, _vec), axis=0)
                    yield (meta, extended_vec)
            else:
                extended_vec = np.concatenate((vec, np.zeros_like(vec)), axis=0)
                yield (meta, extended_vec)
    for cui in list(proto_data.keys()):
        for meta, vec in proto_data.pop(cui):
            extended_vec = np.concatenate((np.zeros_like(vec), vec), axis=0)
            yield (meta, extended_vec)
    assert len(proto_data) == 0


class DenseRetriever:
    def __init__(
        self,
        encoder: nn.Module,
        tokenizer: PreTrainedTokenizer,
        batch_size: int,
        max_length: int,
    ):
        self.encoder = encoder
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_length = max_length

    def generate_mention_vectors(self, ds: torch.utils.data.Dataset) -> T:
        self.encoder.eval()
        return generate_vectors(
            encoder=self.encoder,
            tokenizer=self.tokenizer,
            dataset=ds,
            batch_size=self.batch_size,
            max_length=self.max_length,
        )


class FaissRetriever(DenseRetriever):
    """
    Does entity retrieval over the provided index and encoder.
    """

    def __init__(
        self,
        encoder: nn.Module,
        tokenizer: PreTrainedTokenizer,
        batch_size: int,
        max_length: int,
        index: DenseIndexer,
    ):
        super().__init__(encoder, tokenizer, batch_size, max_length)
        self.index = index

    def index_encoded_data(
        self,
        vector_files: List[str],
        buffer_size: int,
        candidate_ids: Set = None,
        umls_data: Dict = None,
    ):
        """
        Indexes encoded data from a list of files.
        :param vector_files: a list of files
        :param buffer_size: size of a buffer to send for the indexing at once
        :return:
        """
        buffer = []
        for i, item in enumerate(
            iterate_encoded_files(vector_files, candidate_ids, umls_data)
        ):
            buffer.append(item)
            if 0 < buffer_size == len(buffer):
                self.index.index_data(buffer)
                buffer = []
        self.index.index_data(buffer)
        logger.info("Data indexing completed.")

    def get_top_hits(
        self, mention_vectors: np.array, top_k: int = 100
    ) -> List[Tuple[List[object], List[float]]]:
        """
        Does the retrieval of the best matches given the mention vectors batch
        """
        time0 = time.time()
        search = partial(
            self.index.search_knn,
            top_docs=top_k,
        )
        results = []
        num_processes = multiprocessing.cpu_count()
        shard_size = math.ceil(mention_vectors.shape[0] / num_processes)
        shards = []
        for i in range(0, mention_vectors.shape[0], shard_size):
            shards.append(mention_vectors[i:i + shard_size])
        with Pool(processes=num_processes) as pool:
            it = pool.map(search, shards)
            for ret in it:
                results += ret
        # results = self.index.search_knn(mention_vectors, top_k)
        logger.info("index search time: %f sec.", time.time() - time0)
        self.index = None
        return results


def hit(pred: List[str], gold: List[str]) -> bool:
    return all(p in gold for p in pred)


def dedup_ids(ids: List[Dict]) -> List[Dict]:
    deduped_ids = []
    seen_cuis = set()
    for d in ids:
        if all(cui in seen_cuis for cui in d['cuis']):
            continue
        seen_cuis.update(d['cuis'])
        deduped_ids.append(d)
    return deduped_ids


def evaluate(
    ds: torch.utils.data.Dataset,
    result_ent_ids: List[Tuple[List[object], List[float]]],
    lookup_table: str,
    top_ks: List[int] = (1, 5, 50, 100),
) -> None:
    lut = {}
    with open(lookup_table, encoding='utf-8') as f:
        for ln in f:
            cuis, name = ln.strip().split('||')
            cuis = cuis.split('|')
            lut[name] = cuis

    n = len(ds)
    top_k_hits = {top_k: 0 for top_k in top_ks}
    for i in range(len(result_ent_ids)):
        d = ds[i]
        ids, _ = result_ent_ids[i]
        ids = dedup_ids(ids)
        ids = ids[:max(top_ks)]
        candidates = [
            {'cuis': eid['cuis'], 'hit': int(hit(pred=eid['cuis'], gold=d.cuis))}
            for eid in ids
        ]
        lut_cuis = lut.get(d.mention, [])
        if len(lut_cuis) == 1:
            # If the mention only has one ID in the look up table,
            # we use the ID as the top prediction.
            candidates.insert(
                0,
                {'cuis': lut_cuis, 'hit': int(hit(pred=lut_cuis, gold=d.cuis))}
            )
        for top_k in top_k_hits:
            if any(c['hit'] for c in candidates[:top_k]):
                top_k_hits[top_k] += 1

    top_k_acc = {top_k: v / n for top_k, v in top_k_hits.items()}
    logger.info("Top-k accuracy %s", top_k_acc)


@hydra.main(config_path="conf", config_name="run_linking", version_base=None)
def main(cfg: DictConfig):
    set_seed(cfg.seed)

    logger.info("Configuration:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    # Load pretrained.
    config = AutoConfig.from_pretrained(cfg.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name_or_path,
        use_fast=True,
    )
    encoder = AutoModel.from_pretrained(
        cfg.model_name_or_path,
        config=config
    )
    encoder.cuda()
    encoder.eval()
    vector_size = config.hidden_size
    logger.info("Encoder vector_size=%d", vector_size)

    # Load test data.
    ds = hydra.utils.instantiate(cfg.test_data)

    # Init indexer.
    index = DenseFlatIndexer()
    index_buffer_sz = index.buffer_size
    index.init_index(vector_size * 2)

    # candidate ids
    candidate_ids = None
    if cfg.entity_list_ids:
        with open(cfg.entity_list_ids, encoding='utf-8') as f:
            candidate_ids = set(f.read().split('\n'))

    # Start indexing
    input_paths = []
    for pattern in cfg.encoded_files:
        pattern_files = glob.glob(pattern)
        input_paths.extend(pattern_files)
    input_paths = sorted(set(input_paths))

    retriever = FaissRetriever(
        encoder, tokenizer, cfg.batch_size, cfg.max_length, index)
    mentions_tensor = retriever.generate_mention_vectors(ds)

    # Load UMLS knowledge
    umls_data = None
    if cfg.encoded_umls_files:
        umls_data = load_umls_data(cfg.encoded_umls_files, candidate_ids)

    index_path = cfg.index_path
    if index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        logger.info("Indexing encoded data from files: %s", input_paths)
        retriever.index_encoded_data(
            vector_files=input_paths,
            buffer_size=index_buffer_sz,
            candidate_ids=candidate_ids,
            umls_data=umls_data,
        )
        if index_path:
            pathlib.Path(os.path.dirname(index_path)).mkdir(
                parents=True, exist_ok=True)
            retriever.index.serialize(index_path)

    # Duplicate the mention vectors to match the doubled index dimension
    # (UMLS half + prototype half).
    mentions_tensor = torch.cat([mentions_tensor, mentions_tensor], dim=1)

    # To get k different entities, we retrieve 32 * k mentions and then dedup.
    top_ids_and_scores = retriever.get_top_hits(
        mentions_tensor.numpy(), cfg.num_retrievals * 32)

    evaluate(ds, top_ids_and_scores, cfg.entity_list_names)


if __name__ == "__main__":
    main()
```
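The doubled index dimension (`init_index(vector_size * 2)`) and the query duplication (`torch.cat([mentions_tensor, mentions_tensor], dim=1)`) implement a simple trick: each indexed key concatenates a UMLS-knowledge half and a prototype half (zero-padded when one is missing, see `iterate_encoded_files`), so the inner product against a duplicated query decomposes into the sum of the two halves' scores. A small numpy check of that identity:

```python
# Sketch: the doubled-vector inner product decomposes into the sum of
# the UMLS-half score and the prototype-half score.
import numpy as np

rng = np.random.default_rng(0)
q = rng.random(768)  # mention embedding (query)
u = rng.random(768)  # UMLS knowledge embedding
p = rng.random(768)  # prototype embedding

doubled_query = np.concatenate([q, q])  # torch.cat([m, m], dim=1) in main()
doubled_key = np.concatenate([u, p])    # built in iterate_encoded_files()
assert np.isclose(doubled_query @ doubled_key, q @ u + q @ p)
```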
usage/utils.py (ADDED)

```python
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from typing import List, Dict
import os
import time
import logging
import json
import gzip
from dataclasses import dataclass, field

import torch
from torch import Tensor as T
from transformers import PreTrainedTokenizer


logger = logging.getLogger()


@dataclass
class Mention:
    cui: str
    start: int
    end: int
    text: str
    types: List[str]


@dataclass
class ContextualMention:
    mention: str
    cuis: List[str]
    ctx_l: str
    ctx_r: str

    def to_tensor(self, tokenizer: PreTrainedTokenizer, max_length: int) -> T:
        ctx_l_ids = tokenizer.encode(
            text=self.ctx_l,
            add_special_tokens=False,
            max_length=max_length,
            truncation=True,
        )
        ctx_r_ids = tokenizer.encode(
            text=self.ctx_r,
            add_special_tokens=False,
            max_length=max_length,
            truncation=True,
        )
        mention_ids = tokenizer.encode(
            text=self.mention,
            add_special_tokens=False,
            max_length=max_length,
            truncation=True,
        )

        # Concatenate context and mention to the max length.
        token_ids = tokenizer.convert_tokens_to_ids(['<ENT>']) + mention_ids \
            + tokenizer.convert_tokens_to_ids(['</ENT>'])
        max_ctx_len = max_length - len(token_ids) - 2  # Exclude [CLS] and [SEP]
        max_ctx_l_len = max_ctx_len // 2
        max_ctx_r_len = max_ctx_len - max_ctx_l_len
        if len(ctx_l_ids) < max_ctx_l_len and len(ctx_r_ids) < max_ctx_r_len:
            token_ids = ctx_l_ids + token_ids + ctx_r_ids
        elif len(ctx_l_ids) >= max_ctx_l_len and len(ctx_r_ids) >= max_ctx_r_len:
            token_ids = ctx_l_ids[-max_ctx_l_len:] + token_ids \
                + ctx_r_ids[:max_ctx_r_len]
        elif len(ctx_l_ids) >= max_ctx_l_len:
            ctx_l_len = max_ctx_len - len(ctx_r_ids)
            token_ids = ctx_l_ids[-ctx_l_len:] + token_ids + ctx_r_ids
        else:
            ctx_r_len = max_ctx_len - len(ctx_l_ids)
            token_ids = ctx_l_ids + token_ids + ctx_r_ids[:ctx_r_len]

        token_ids = [tokenizer.cls_token_id] + token_ids

        # The above snippet doesn't guarantee the max length limit.
        token_ids = token_ids[:max_length - 1] + [tokenizer.sep_token_id]

        if len(token_ids) < max_length:
            token_ids = token_ids + [tokenizer.pad_token_id] * (max_length - len(token_ids))

        return torch.tensor(token_ids)


@dataclass
class Document:
    id: str = None
    title: str = None
    abstract: str = None
    mentions: List[Mention] = field(default_factory=list)

    def concatenate_text(self) -> str:
        return ' '.join([self.title, self.abstract])

    @classmethod
    def from_PubTator(cls, path: str, split_path_prefix: str) -> Dict[str, List]:
        docs = []
        with gzip.open(path, 'rb') as f:
            for b in f.read().decode().strip().split('\n\n'):
                d = cls()
                s = ''
                for i, ln in enumerate(b.split('\n')):
                    if i == 0:
                        id, type, text = ln.strip().split('|', 2)
                        assert type == 't'
                        d.id, d.title = id, text
                    elif i == 1:
                        id, type, text = ln.strip().split('|', 2)
                        assert type == 'a'
                        assert d.id == id
                        d.abstract = text
                        s = d.concatenate_text()
                    else:
                        items = ln.strip().split('\t')
                        assert d.id == items[0]
                        cui = items[5].split('UMLS:')[-1]
                        assert len(cui) == 8, breakpoint()
                        m = Mention(
                            cui=cui,
                            start=int(items[1]),
                            end=int(items[2]),
                            text=items[3],
                            types=items[4].split(',')
                        )
                        assert m.text == s[m.start: m.end]
                        d.mentions.append(m)
                docs.append(d)
        dataset = split_dataset(docs, split_path_prefix)
        print_dataset_stats(dataset)
        return dataset

    def to_contextual_mentions(self, max_length: int = 64) -> List[ContextualMention]:
        text = self.concatenate_text()
        mentions = []
        for m in self.mentions:
            assert m.text == text[m.start:m.end]
            # Context
            ctx_l, ctx_r = text[:m.start].strip().split(), text[m.end:].strip().split()
            ctx_l, ctx_r = ' '.join(ctx_l[-max_length:]), ' '.join(ctx_r[:max_length])
            cm = ContextualMention(
                mention=m.text,
                cuis=[m.cui],
                ctx_l=ctx_l,
                ctx_r=ctx_r,
            )
            mentions.append(cm)
        return mentions


def split_dataset(docs: List, split_path_prefix: str) -> Dict[str, List]:
    split_kv = {'train': 'trng', 'dev': 'dev', 'test': 'test'}
    id_to_split = {}
    dataset = {}
    for k, v in split_kv.items():
        dataset[k] = []
        path = split_path_prefix + v + '.txt'
        for i in open(path, encoding='utf-8').read().strip().split('\n'):
            assert i not in id_to_split, breakpoint()
            id_to_split[i] = k
    for doc in docs:
        split = id_to_split[doc.id]
        dataset[split].append(doc)
    return dataset


def print_dataset_stats(dataset: Dict[str, List[Document]]) -> None:
    all_docs = []
    for v in dataset.values():
        all_docs.extend(v)
    for split, docs in {'all': all_docs, **dataset}.items():
        logger.info(f"***** {split} *****")
        logger.info(f"Documents: {len(docs)}")
        logger.info(f"Mentions: {sum(len(d.mentions) for d in docs)}")
        cuis = set()
        for d in docs:
            for m in d.mentions:
                cuis.add(m.cui)
        logger.info(f"Mentioned concepts: {len(cuis)}")


class MedMentionsDataset(torch.utils.data.Dataset):

    def __init__(self, dataset_path: str, split: str) -> None:
        super().__init__()
        self.dataset_path = dataset_path
        self.docs = Document.from_PubTator(
            path=os.path.join(self.dataset_path, 'corpus_pubtator.txt.gz'),
            split_path_prefix=os.path.join(self.dataset_path, 'corpus_pubtator_pmids_')
        )[split]
        self.mentions = []
        self.name_to_cuis = {}
        self._post_init()

    def _post_init(self):
        for d in self.docs:
            self.mentions.extend(d.to_contextual_mentions())
        for m in self.mentions:
            if m.mention not in self.name_to_cuis:
                self.name_to_cuis[m.mention] = set()
            self.name_to_cuis[m.mention].update(m.cuis)

    def __getitem__(self, index: int) -> ContextualMention:
        return self.mentions[index]

    def __len__(self) -> int:
        return len(self.mentions)


class PreprocessedDataset(torch.utils.data.Dataset):

    def __init__(self, dataset_path: str) -> None:
        super().__init__()
        self.file = dataset_path
        self.data = []
        self.load_data()

    def load_data(self) -> None:
        with open(self.file, encoding='utf-8') as f:
            logger.info("Reading file %s", self.file)
            for ln in f:
                if ln.strip():
                    self.data.append(json.loads(ln))
        logger.info("Loaded data size: %d", len(self.data))

    def __getitem__(self, index: int) -> ContextualMention:
        d = self.data[index]
        return ContextualMention(
            ctx_l=d['context_left'],
            ctx_r=d['context_right'],
            mention=d['mention'],
            cuis=d['cuis'],
        )

    def __len__(self) -> int:
        return len(self.data)


def generate_vectors(
    encoder: torch.nn.Module,
    tokenizer: PreTrainedTokenizer,
    dataset: torch.utils.data.Dataset,
    batch_size: int,
    max_length: int,
    is_prototype: bool = False,
):
    n = len(dataset)
    total = 0
    results = []
    start_time = time.time()
    logger.info("Start encoding...")
    for i, batch_start in enumerate(range(0, n, batch_size)):
        batch = [dataset[i] for i in range(batch_start, min(n, batch_start + batch_size))]
        batch_token_tensors = [m.to_tensor(tokenizer, max_length) for m in batch]

        ids_batch = torch.stack(batch_token_tensors, dim=0).cuda()
        seg_batch = torch.zeros_like(ids_batch)
        attn_mask = (ids_batch != tokenizer.pad_token_id)

        with torch.inference_mode():
            out = encoder(
                input_ids=ids_batch,
                token_type_ids=seg_batch,
                attention_mask=attn_mask
            )
            out = out[0][:, 0, :]
        out = out.cpu()

        num_mentions = out.size(0)
        total += num_mentions

        if is_prototype:
            meta_batch = [{'cuis': m.cuis} for m in batch]
            assert len(meta_batch) == num_mentions
            results.extend([(meta_batch[i], out[i].view(-1).numpy()) for i in range(num_mentions)])
        else:
            results.extend(out.cpu().split(1, dim=0))

        if (i + 1) % 10 == 0:
            eta = (n - total) * (time.time() - start_time) / 60 / total
            logger.info(f"Batch={i + 1}, Encoded mentions={total}, ETA={eta:.1f}m")

    assert len(results) == n
    logger.info(f"Total encoded mentions={n}")
    if not is_prototype:
        results = torch.cat(results, dim=0)

    return results
```
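`PreprocessedDataset` is not used by the two scripts above, but it accepts pre-extracted mentions as JSON Lines: one object per line with `mention`, `context_left`, `context_right`, and `cuis` keys, exactly as `__getitem__` reads them. A minimal sketch with a hypothetical file name:

```python
# Sketch: feed PreprocessedDataset a hypothetical one-line JSONL file.
import json

from utils import PreprocessedDataset

example = {
    "mention": "aspirin",
    "context_left": "The patient was treated with",
    "context_right": "for chest pain.",
    "cuis": ["C0004057"],  # illustrative CUI
}
with open("mentions.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(example) + "\n")

ds = PreprocessedDataset("mentions.jsonl")
print(ds[0])  # ContextualMention(mention='aspirin', cuis=['C0004057'], ...)
```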