shengz committed
Commit
e3ef0b9
1 Parent(s): cfd4687

Add the example usage.

usage/README.md ADDED
@@ -0,0 +1,43 @@
+ # Knowledge-Rich Self-Supervision (KRISS) for Biomedical Entity Linking
+
+ Usage code for the entity linking approach described in the following paper:
+ ```bibtex
+ @article{kriss,
+ author = {Sheng Zhang and Hao Cheng and Shikhar Vashishth and Cliff Wong and Jinfeng Xiao and Xiaodong Liu and Tristan Naumann and Jianfeng Gao and Hoifung Poon},
+ title = {Knowledge-Rich Self-Supervision for Biomedical Entity Linking},
+ year = {2021},
+ url = {https://arxiv.org/abs/2112.07887},
+ eprinttype = {arXiv},
+ eprint = {2112.07887},
+ }
+ ```
+ [https://arxiv.org/pdf/2112.07887.pdf](https://arxiv.org/pdf/2112.07887.pdf)
+
+ ## Usage of KRISS for Entity Linking
+
+ Here, we use the [MedMentions](https://github.com/chanzuckerberg/MedMentions) data to show how to 1) generate prototype embeddings, and 2) run entity linking.
+
+ (We are currently unable to release the self-supervised mention examples, because they require UMLS and PubMed licenses.)
+
+
+ ### 1. Create conda environment and install requirements
+ ```bash
+ conda create -n kriss -y python=3.8 && conda activate kriss
+ pip install -r requirements.txt
+ ```
+
+ ### 2. Download the MedMentions dataset
+
+ ```bash
+ git clone https://github.com/chanzuckerberg/MedMentions.git
+ ```
+
+ ### 3. Generate prototype embeddings
+ ```bash
+ python generate_prototypes.py
+ ```
+
+ ### 4. Run entity linking
+ ```bash
+ python run_entity_linking.py
+ ```
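Both scripts are Hydra applications, so any field in `conf/*.yaml` can be overridden from the command line. A couple of illustrative invocations (the values are examples, not requirements):

```bash
# Encode with a smaller batch if GPU memory is tight
python generate_prototypes.py batch_size=128

# Report accuracy at different cutoffs
python run_entity_linking.py "top_ks=[1,5,10]"
```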
usage/conf/generate_prototypes.yaml ADDED
@@ -0,0 +1,9 @@
+ model_name_or_path: microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL
+ train_data:
+   _target_: utils.MedMentionsDataset
+   dataset_path: MedMentions/full/data/
+   split: train
+ batch_size: 256
+ max_length: 64
+ output_prototypes: prototypes/embeddings
+ output_name_cuis: prototypes/name_cuis
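The `train_data` node follows Hydra's `_target_` convention: `generate_prototypes.py` calls `hydra.utils.instantiate(cfg.train_data)`, which is roughly equivalent to the direct construction below (a sketch for orientation, not part of the scripts):

```python
from utils import MedMentionsDataset

# What hydra.utils.instantiate(cfg.train_data) builds from the node above:
# every keyword other than _target_ is passed to the constructor.
ds = MedMentionsDataset(dataset_path="MedMentions/full/data/", split="train")
print(len(ds))  # number of contextual mentions in the training split
```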
usage/conf/run_linking.yaml ADDED
@@ -0,0 +1,26 @@
+ # path to pretrained model and tokenizer
+ model_name_or_path: microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL
+
+ test_data:
+   _target_: utils.MedMentionsDataset
+   dataset_path: MedMentions/full/data/
+   split: test
+
+ # paths to encoded data
+ encoded_files: [
+   prototypes/embeddings
+ ]
+
+ encoded_umls_files: []
+
+ entity_list_ids:
+
+ entity_list_names: prototypes/name_cuis
+
+ index_path:
+
+ seed: 12345
+ batch_size: 256
+ max_length: 64
+ num_retrievals: 100
+ top_ks: [1, 5, 50, 100]
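Two fields are worth calling out: entries in `encoded_files` are glob patterns (they are expanded with `glob.glob` in `run_entity_linking.py`), and `index_path`/`entity_list_ids` are optional. A couple of example overrides, with illustrative paths:

```bash
# Serialize the FAISS index on the first run and reuse it on later runs
python run_entity_linking.py index_path=prototypes/index

# Restrict retrieval to a candidate list with one CUI per line
python run_entity_linking.py entity_list_ids=candidate_cuis.txt
```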
usage/generate_prototypes.py ADDED
@@ -0,0 +1,61 @@
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Command line tool that produces embeddings for a large set of entity mentions
+ based on the pretrained mention encoder.
+ """
+ import logging
+ import os
+ import pathlib
+ import pickle
+
+ import hydra
+ from omegaconf import DictConfig, OmegaConf
+ from transformers import AutoConfig, AutoTokenizer, AutoModel
+
+ from utils import generate_vectors
+
+
+ # Setup logger
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+ log_formatter = logging.Formatter(
+     "[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s"
+ )
+ console = logging.StreamHandler()
+ console.setFormatter(log_formatter)
+ logger.addHandler(console)
+
+
+ @hydra.main(config_path="conf", config_name="generate_prototypes", version_base=None)
+ def main(cfg: DictConfig):
+     logger.info("Configuration:")
+     logger.info("%s", OmegaConf.to_yaml(cfg))
+
+     config = AutoConfig.from_pretrained(cfg.model_name_or_path)
+     tokenizer = AutoTokenizer.from_pretrained(
+         cfg.model_name_or_path,
+         use_fast=True,
+     )
+     encoder = AutoModel.from_pretrained(
+         cfg.model_name_or_path,
+         config=config
+     )
+     encoder.cuda()
+     encoder.eval()
+
+     ds = hydra.utils.instantiate(cfg.train_data)
+     data = generate_vectors(encoder, tokenizer, ds, cfg.batch_size, cfg.max_length, is_prototype=True)
+     pathlib.Path(os.path.dirname(cfg.output_prototypes)).mkdir(parents=True, exist_ok=True)
+     logger.info("Writing results to %s", cfg.output_prototypes)
+     with open(cfg.output_prototypes, mode="wb") as f:
+         pickle.dump(data, f)
+     with open(cfg.output_name_cuis, 'w') as f:
+         for name, cuis in ds.name_to_cuis.items():
+             f.write('|'.join(cuis) + '||' + name + '\n')
+     logger.info("Total data processed %d. Written to %s", len(data), cfg.output_prototypes)
+
+
+ if __name__ == "__main__":
+     main()
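For reference, the two output files have simple formats: `output_prototypes` is a pickled list of `(meta, vector)` pairs where `meta` is `{'cuis': [...]}`, and `output_name_cuis` has one `CUI1|CUI2||mention name` line per distinct mention string. A short sketch of reading them back, assuming the default output paths from the config:

```python
import pickle

# Pickled list of ({'cuis': [...]}, numpy vector) pairs written by generate_prototypes.py
with open("prototypes/embeddings", "rb") as f:
    prototypes = pickle.load(f)
meta, vec = prototypes[0]
print(meta["cuis"], vec.shape)  # CUIs of the first prototype and its embedding dimension

# One "CUI1|CUI2||mention name" line per distinct mention string
with open("prototypes/name_cuis", encoding="utf-8") as f:
    cuis, name = next(f).strip().split("||")
    print(cuis.split("|"), name)
```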
usage/requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers==4.17.0
+ torch==1.11
+ hydra-core==1.2.0
+ faiss-gpu==1.7.0
usage/run_entity_linking.py ADDED
@@ -0,0 +1,440 @@
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Run entity linking
+ """
+
+ import os
+ import glob
+ import logging
+ import pathlib
+ import pickle
+ import time
+ import math
+ import multiprocessing
+ from typing import List, Tuple, Dict, Iterator, Set
+ from functools import partial
+ from multiprocessing.dummy import Pool
+
+ import hydra
+ import numpy as np
+ import torch
+ from omegaconf import DictConfig, OmegaConf
+ from torch import Tensor as T
+ from torch import nn
+ import faiss
+
+ from transformers import (
+     set_seed,
+     AutoConfig,
+     AutoTokenizer,
+     AutoModel,
+     PreTrainedTokenizer,
+ )
+ from utils import generate_vectors
+
+
+ # Setup logger
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+ log_formatter = logging.Formatter(
+     "[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s"
+ )
+ console = logging.StreamHandler()
+ console.setFormatter(log_formatter)
+ logger.addHandler(console)
+
+
+ class DenseIndexer(object):
+     def __init__(self, buffer_size: int = 50000):
+         self.buffer_size = buffer_size
+         self.index_id_to_db_id = []
+         self.index = None
+
+     def init_index(self, vector_sz: int):
+         raise NotImplementedError
+
+     def index_data(self, data: List[Tuple[object, np.array]]):
+         raise NotImplementedError
+
+     def get_index_name(self):
+         raise NotImplementedError
+
+     def search_knn(
+         self, query_vectors: np.array, top_docs: int
+     ) -> List[Tuple[List[object], List[float]]]:
+         raise NotImplementedError
+
+     def serialize(self, file: str):
+         logger.info("Serializing index to %s", file)
+
+         if os.path.isdir(file):
+             index_file = os.path.join(file, "index.dpr")
+             meta_file = os.path.join(file, "index_meta.dpr")
+         else:
+             index_file = file + ".index.dpr"
+             meta_file = file + ".index_meta.dpr"
+
+         faiss.write_index(self.index, index_file)
+         with open(meta_file, mode="wb") as f:
+             pickle.dump(self.index_id_to_db_id, f)
+
+     def get_files(self, path: str):
+         if os.path.isdir(path):
+             index_file = os.path.join(path, "index.dpr")
+             meta_file = os.path.join(path, "index_meta.dpr")
+         else:
+             index_file = path + ".index.dpr"
+             meta_file = path + ".index_meta.dpr"
+         return index_file, meta_file
+
+     def index_exists(self, path: str):
+         index_file, meta_file = self.get_files(path)
+         return os.path.isfile(index_file) and os.path.isfile(meta_file)
+
+     def deserialize(self, path: str):
+         logger.info("Loading index from %s", path)
+         index_file, meta_file = self.get_files(path)
+
+         self.index = faiss.read_index(index_file)
+         logger.info(
+             "Loaded index of type %s and size %d", type(self.index), self.index.ntotal
+         )
+
+         with open(meta_file, "rb") as reader:
+             self.index_id_to_db_id = pickle.load(reader)
+         assert (
+             len(self.index_id_to_db_id) == self.index.ntotal
+         ), "Deserialized index_id_to_db_id should match faiss index size"
+
+     def _update_id_mapping(self, db_ids: List) -> int:
+         self.index_id_to_db_id.extend(db_ids)
+         return len(self.index_id_to_db_id)
+
+
+ class DenseFlatIndexer(DenseIndexer):
+     def __init__(self, buffer_size: int = 50000):
+         super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)
+
+     def init_index(self, vector_sz: int):
+         self.index = faiss.IndexFlatIP(vector_sz)
+
+     def index_data(self, data: List[Tuple[object, np.array]]):
+         n = len(data)
+         # indexing in batches is beneficial for many faiss index types
+         for i in range(0, n, self.buffer_size):
+             db_ids = [t[0] for t in data[i : i + self.buffer_size]]
+             vectors = [
+                 np.reshape(t[1], (1, -1)) for t in data[i : i + self.buffer_size]
+             ]
+             vectors = np.concatenate(vectors, axis=0)
+             total_data = self._update_id_mapping(db_ids)
+             self.index.add(vectors)
+             logger.info("data indexed %d", total_data)
+
+         indexed_cnt = len(self.index_id_to_db_id)
+         logger.info("Total data indexed %d", indexed_cnt)
+
+     def search_knn(
+         self, query_vectors: np.array, top_docs: int, batch_size: int = 4096,
+     ) -> List[Tuple[List[object], List[float]]]:
+         num_queries = query_vectors.shape[0]
+         scores, indexes = [], []
+         for start in range(0, num_queries, batch_size):
+             logger.info(f"Searched {start} queries.")
+             batch_vectors = query_vectors[start:start + batch_size]
+             batch_scores, batch_indexes = self.index.search(batch_vectors, top_docs)
+             scores.extend(batch_scores)
+             indexes.extend(batch_indexes)
+         # convert to external ids
+         db_ids = [
+             [self.index_id_to_db_id[i] for i in query_top_idxs]
+             for query_top_idxs in indexes
+         ]
+         result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
+         return result
+
+     def get_index_name(self):
+         return "flat_index"
+
+
+ def load_umls_data(files_patterns: List[str], candidate_ids: Dict = None) -> Dict:
+     input_paths = []
+     for pattern in files_patterns:
+         pattern_files = glob.glob(pattern)
+         input_paths.extend(pattern_files)
+     umls_data = {}
+     for file in sorted(input_paths):
+         logger.info("Reading encoded UMLS data from file %s", file)
+         with open(file, "rb") as reader:
+             for meta, vec in pickle.load(reader):
+                 assert len(meta['cuis']) == 1, breakpoint()
+                 cui = meta['cuis'][0]
+                 if candidate_ids and cui not in candidate_ids:
+                     continue
+                 umls_data[cui] = (meta, vec)
+     logger.info(f"Loaded UMLS data = {len(umls_data)}.")
+     return umls_data
+
+
+ def iterate_encoded_files(
+     vector_files: list,
+     candidate_ids: Set = None,
+     umls_data: Dict = None,
+ ) -> Iterator:
+     logger.info("Loading encoded prototype embeddings...")
+     proto_data = {}
+     for file in vector_files:
+         logger.info("Reading file %s", file)
+         with open(file, "rb") as reader:
+             for meta, vec in pickle.load(reader):
+                 cuis = meta['cuis']
+                 if candidate_ids and all(c not in candidate_ids for c in cuis):
+                     continue
+                 for cui in cuis:
+                     proto_data.setdefault(cui, []).append((meta, vec))
+     # Concatenate prototype embs with additional knowledge embs from UMLS.
+     if umls_data is not None:
+         for cui, (meta, vec) in umls_data.items():
+             if cui in proto_data:
+                 for _, _vec in proto_data.pop(cui):
+                     extended_vec = np.concatenate((vec, _vec), axis=0)
+                     yield (meta, extended_vec)
+             else:
+                 extended_vec = np.concatenate((vec, np.zeros_like(vec)), axis=0)
+                 yield (meta, extended_vec)
+     for cui in list(proto_data.keys()):
+         for meta, vec in proto_data.pop(cui):
+             extended_vec = np.concatenate((np.zeros_like(vec), vec), axis=0)
+             yield (meta, extended_vec)
+     assert len(proto_data) == 0
+
+
+ class DenseRetriever:
+     def __init__(
+         self,
+         encoder: nn.Module,
+         tokenizer: PreTrainedTokenizer,
+         batch_size: int,
+         max_length: int,
+     ):
+         self.encoder = encoder
+         self.tokenizer = tokenizer
+         self.batch_size = batch_size
+         self.max_length = max_length
+
+     def generate_mention_vectors(self, ds: torch.utils.data.Dataset) -> T:
+         self.encoder.eval()
+         return generate_vectors(
+             encoder=self.encoder,
+             tokenizer=self.tokenizer,
+             dataset=ds,
+             batch_size=self.batch_size,
+             max_length=self.max_length,
+         )
+
+
+ class FaissRetriever(DenseRetriever):
+     """
+     Performs entity retrieval over the provided index with the given encoder.
+     """
+
+     def __init__(
+         self,
+         encoder: nn.Module,
+         tokenizer: PreTrainedTokenizer,
+         batch_size: int,
+         max_length: int,
+         index: DenseIndexer,
+     ):
+         super().__init__(encoder, tokenizer, batch_size, max_length)
+         self.index = index
+
+     def index_encoded_data(
+         self,
+         vector_files: List[str],
+         buffer_size: int,
+         candidate_ids: Set = None,
+         umls_data: Dict = None,
+     ):
+         """
+         Indexes encoded data from a list of files.
+         :param vector_files: a list of files
+         :param buffer_size: size of a buffer to send for the indexing at once
+         :return:
+         """
+         buffer = []
+         for i, item in enumerate(
+             iterate_encoded_files(vector_files, candidate_ids, umls_data)
+         ):
+             buffer.append(item)
+             if 0 < buffer_size == len(buffer):
+                 self.index.index_data(buffer)
+                 buffer = []
+         self.index.index_data(buffer)
+         logger.info("Data indexing completed.")
+
+     def get_top_hits(
+         self, mention_vectors: np.array, top_k: int = 100
+     ) -> List[Tuple[List[object], List[float]]]:
+         """
+         Retrieves the best matches for the given batch of mention vectors.
+         """
+         time0 = time.time()
+         search = partial(
+             self.index.search_knn,
+             top_docs=top_k,
+         )
+         results = []
+         num_processes = multiprocessing.cpu_count()
+         shard_size = math.ceil(mention_vectors.shape[0] / num_processes)
+         shards = []
+         for i in range(0, mention_vectors.shape[0], shard_size):
+             shards.append(mention_vectors[i:i + shard_size])
+         with Pool(processes=num_processes) as pool:
+             it = pool.map(search, shards)
+             for ret in it:
+                 results += ret
+         # results = self.index.search_knn(mention_vectors, top_k)
+         logger.info("index search time: %f sec.", time.time() - time0)
+         self.index = None
+         return results
+
+
+ def hit(pred: List[str], gold: List[str]) -> bool:
+     return all(p in gold for p in pred)
+
+
+ def dedup_ids(ids: List[Dict]) -> List[Dict]:
+     deduped_ids = []
+     seen_cuis = set()
+     for d in ids:
+         if all(cui in seen_cuis for cui in d['cuis']):
+             continue
+         seen_cuis.update(d['cuis'])
+         deduped_ids.append(d)
+     return deduped_ids
+
+
+ def evaluate(
+     ds: torch.utils.data.Dataset,
+     result_ent_ids: List[Tuple[List[object], List[float]]],
+     lookup_table: str,
+     top_ks: List[int] = (1, 5, 50, 100),
+ ) -> List[Dict]:
+     lut = {}
+     with open(lookup_table, encoding='utf-8') as f:
+         for ln in f:
+             cuis, name = ln.strip().split('||')
+             cuis = cuis.split('|')
+             lut[name] = cuis
+
+     n = len(ds)
+     top_k_hits = {top_k: 0 for top_k in top_ks}
+     for i in range(len(result_ent_ids)):
+         d = ds[i]
+         ids, _ = result_ent_ids[i]
+         ids = dedup_ids(ids)
+         ids = ids[:max(top_ks)]
+         candidates = [
+             {'cuis': eid['cuis'], 'hit': int(hit(pred=eid['cuis'], gold=d.cuis))}
+             for eid in ids
+         ]
+         lut_cuis = lut.get(d.mention, [])
+         if len(lut_cuis) == 1:
+             # If the mention only has one ID in the look-up table,
+             # we use that ID as the top prediction.
+             candidates.insert(
+                 0,
+                 {'cuis': lut_cuis, 'hit': int(hit(pred=lut_cuis, gold=d.cuis))}
+             )
+         for top_k in top_k_hits:
+             if any(c['hit'] for c in candidates[:top_k]):
+                 top_k_hits[top_k] += 1
+
+     top_k_acc = {top_k: v / n for top_k, v in top_k_hits.items()}
+     logger.info("Top-k accuracy %s", top_k_acc)
+
+
+ @hydra.main(config_path="conf", config_name="run_linking", version_base=None)
+ def main(cfg: DictConfig):
+     set_seed(cfg.seed)
+
+     logger.info("Configuration:")
+     logger.info("%s", OmegaConf.to_yaml(cfg))
+
+     # Load pretrained.
+     config = AutoConfig.from_pretrained(cfg.model_name_or_path)
+     tokenizer = AutoTokenizer.from_pretrained(
+         cfg.model_name_or_path,
+         use_fast=True,
+     )
+     encoder = AutoModel.from_pretrained(
+         cfg.model_name_or_path,
+         config=config
+     )
+     encoder.cuda()
+     encoder.eval()
+     vector_size = config.hidden_size
+     logger.info("Encoder vector_size=%d", vector_size)
+
+     # Load test data.
+     ds = hydra.utils.instantiate(cfg.test_data)
+
+     # Init indexer.
+     index = DenseFlatIndexer()
+     index_buffer_sz = index.buffer_size
+     index.init_index(vector_size * 2)
+
+     # candidate ids
+     candidate_ids = None
+     if cfg.entity_list_ids:
+         with open(cfg.entity_list_ids, encoding='utf-8') as f:
+             candidate_ids = set(f.read().split('\n'))
+
+     # Start indexing
+     input_paths = []
+     for pattern in cfg.encoded_files:
+         pattern_files = glob.glob(pattern)
+         input_paths.extend(pattern_files)
+     input_paths = sorted(set(input_paths))
+
+     retriever = FaissRetriever(
+         encoder, tokenizer, cfg.batch_size, cfg.max_length, index)
+     mentions_tensor = retriever.generate_mention_vectors(ds)
+
+     # Load UMLS knowledge
+     umls_data = None
+     if cfg.encoded_umls_files:
+         umls_data = load_umls_data(cfg.encoded_umls_files, candidate_ids)
+
+     index_path = cfg.index_path
+     if index_path and index.index_exists(index_path):
+         logger.info("Index path: %s", index_path)
+         retriever.index.deserialize(index_path)
+     else:
+         logger.info("Indexing encoded data from files: %s", input_paths)
+         retriever.index_encoded_data(
+             vector_files=input_paths,
+             buffer_size=index_buffer_sz,
+             candidate_ids=candidate_ids,
+             umls_data=umls_data,
+         )
+         if index_path:
+             pathlib.Path(os.path.dirname(index_path)).mkdir(
+                 parents=True, exist_ok=True)
+             retriever.index.serialize(index_path)
+
+     # Duplicate the mention vectors to match the doubled index dimension
+     # (prototype block + UMLS block).
+     mentions_tensor = torch.cat([mentions_tensor, mentions_tensor], dim=1)
+
+     # To get k different entities, we retrieve 32 * k mentions and then dedup.
+     top_ids_and_scores = retriever.get_top_hits(
+         mentions_tensor.numpy(), cfg.num_retrievals * 32)
+
+     evaluate(ds, top_ids_and_scores, cfg.entity_list_names)
+
+
+ if __name__ == "__main__":
+     main()
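The evaluation convention above is strict: a retrieved entry counts as a hit only if every one of its CUIs is in the gold set, and duplicate entries are collapsed by `dedup_ids` before the top-k cutoffs are applied. A toy illustration of those two helpers (the CUIs are made up for the example):

```python
from run_entity_linking import dedup_ids, hit

preds = [
    {"cuis": ["C0011849"]},              # correct prediction
    {"cuis": ["C0011849"]},              # duplicate, dropped by dedup_ids
    {"cuis": ["C0020538", "C0011849"]},  # partial overlap still counts as a miss
]
gold = ["C0011849"]

deduped = dedup_ids(preds)
print(len(deduped))                             # 2
print([hit(p["cuis"], gold) for p in deduped])  # [True, False]
```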
usage/utils.py ADDED
@@ -0,0 +1,289 @@
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ from typing import List, Dict
+ import os
+ import time
+ import logging
+ import json
+ import gzip
+ from dataclasses import dataclass, field
+
+ import torch
+ from torch import Tensor as T
+ from transformers import PreTrainedTokenizer
+
+
+ logger = logging.getLogger()
+
+
+ @dataclass
+ class Mention:
+     cui: str
+     start: int
+     end: int
+     text: str
+     types: List[str]
+
+
+ @dataclass
+ class ContextualMention:
+     mention: str
+     cuis: List[str]
+     ctx_l: str
+     ctx_r: str
+
+     def to_tensor(self, tokenizer: PreTrainedTokenizer, max_length: int) -> T:
+         ctx_l_ids = tokenizer.encode(
+             text=self.ctx_l,
+             add_special_tokens=False,
+             max_length=max_length,
+             truncation=True,
+         )
+         ctx_r_ids = tokenizer.encode(
+             text=self.ctx_r,
+             add_special_tokens=False,
+             max_length=max_length,
+             truncation=True,
+         )
+         mention_ids = tokenizer.encode(
+             text=self.mention,
+             add_special_tokens=False,
+             max_length=max_length,
+             truncation=True,
+         )
+
+         # Concatenate context and mention to the max length.
+         token_ids = tokenizer.convert_tokens_to_ids(['<ENT>']) + mention_ids \
+             + tokenizer.convert_tokens_to_ids(['</ENT>'])
+         max_ctx_len = max_length - len(token_ids) - 2  # Exclude [CLS] and [SEP]
+         max_ctx_l_len = max_ctx_len // 2
+         max_ctx_r_len = max_ctx_len - max_ctx_l_len
+         if len(ctx_l_ids) < max_ctx_l_len and len(ctx_r_ids) < max_ctx_r_len:
+             token_ids = ctx_l_ids + token_ids + ctx_r_ids
+         elif len(ctx_l_ids) >= max_ctx_l_len and len(ctx_r_ids) >= max_ctx_r_len:
+             token_ids = ctx_l_ids[-max_ctx_l_len:] + token_ids \
+                 + ctx_r_ids[:max_ctx_r_len]
+         elif len(ctx_l_ids) >= max_ctx_l_len:
+             ctx_l_len = max_ctx_len - len(ctx_r_ids)
+             token_ids = ctx_l_ids[-ctx_l_len:] + token_ids + ctx_r_ids
+         else:
+             ctx_r_len = max_ctx_len - len(ctx_l_ids)
+             token_ids = ctx_l_ids + token_ids + ctx_r_ids[:ctx_r_len]
+
+         token_ids = [tokenizer.cls_token_id] + token_ids
+
+         # The above snippet doesn't guarantee the max length limit.
+         token_ids = token_ids[:max_length - 1] + [tokenizer.sep_token_id]
+
+         if len(token_ids) < max_length:
+             token_ids = token_ids + [tokenizer.pad_token_id] * (max_length - len(token_ids))
+
+         return torch.tensor(token_ids)
+
+
+ @dataclass
+ class Document:
+     id: str = None
+     title: str = None
+     abstract: str = None
+     mentions: List[Mention] = field(default_factory=list)
+
+     def concatenate_text(self) -> str:
+         return ' '.join([self.title, self.abstract])
+
+     @classmethod
+     def from_PubTator(cls, path: str, split_path_prefix: str) -> Dict[str, List]:
+         docs = []
+         with gzip.open(path, 'rb') as f:
+             for b in f.read().decode().strip().split('\n\n'):
+                 d = cls()
+                 s = ''
+                 for i, ln in enumerate(b.split('\n')):
+                     if i == 0:
+                         id, type, text = ln.strip().split('|', 2)
+                         assert type == 't'
+                         d.id, d.title = id, text
+                     elif i == 1:
+                         id, type, text = ln.strip().split('|', 2)
+                         assert type == 'a'
+                         assert d.id == id
+                         d.abstract = text
+                         s = d.concatenate_text()
+                     else:
+                         items = ln.strip().split('\t')
+                         assert d.id == items[0]
+                         cui = items[5].split('UMLS:')[-1]
+                         assert len(cui) == 8, breakpoint()
+                         m = Mention(
+                             cui=cui,
+                             start=int(items[1]),
+                             end=int(items[2]),
+                             text=items[3],
+                             types=items[4].split(',')
+                         )
+                         assert m.text == s[m.start: m.end]
+                         d.mentions.append(m)
+                 docs.append(d)
+         dataset = split_dataset(docs, split_path_prefix)
+         print_dataset_stats(dataset)
+         return dataset
+
+     def to_contextual_mentions(self, max_length: int = 64) -> List[ContextualMention]:
+         text = self.concatenate_text()
+         mentions = []
+         for m in self.mentions:
+             assert m.text == text[m.start:m.end]
+             # Context
+             ctx_l, ctx_r = text[:m.start].strip().split(), text[m.end:].strip().split()
+             ctx_l, ctx_r = ' '.join(ctx_l[-max_length:]), ' '.join(ctx_r[:max_length])
+             cm = ContextualMention(
+                 mention=m.text,
+                 cuis=[m.cui],
+                 ctx_l=ctx_l,
+                 ctx_r=ctx_r,
+             )
+             mentions.append(cm)
+         return mentions
+
+
+ def split_dataset(docs: List, split_path_prefix: str) -> Dict[str, List]:
+     split_kv = {'train': 'trng', 'dev': 'dev', 'test': 'test'}
+     id_to_split = {}
+     dataset = {}
+     for k, v in split_kv.items():
+         dataset[k] = []
+         path = split_path_prefix + v + '.txt'
+         for i in open(path, encoding='utf-8').read().strip().split('\n'):
+             assert i not in id_to_split, breakpoint()
+             id_to_split[i] = k
+     for doc in docs:
+         split = id_to_split[doc.id]
+         dataset[split].append(doc)
+     return dataset
+
+
+ def print_dataset_stats(dataset: Dict[str, List[Document]]) -> None:
+     all_docs = []
+     for v in dataset.values():
+         all_docs.extend(v)
+     for split, docs in {'all': all_docs, **dataset}.items():
+         logger.info(f"***** {split} *****")
+         logger.info(f"Documents: {len(docs)}")
+         logger.info(f"Mentions: {sum(len(d.mentions) for d in docs)}")
+         cuis = set()
+         for d in docs:
+             for m in d.mentions:
+                 cuis.add(m.cui)
+         logger.info(f"Mentioned concepts: {len(cuis)}")
+
+
+ class MedMentionsDataset(torch.utils.data.Dataset):
+
+     def __init__(self, dataset_path: str, split: str) -> None:
+         super().__init__()
+         self.dataset_path = dataset_path
+         self.docs = Document.from_PubTator(
+             path=os.path.join(self.dataset_path, 'corpus_pubtator.txt.gz'),
+             split_path_prefix=os.path.join(self.dataset_path, 'corpus_pubtator_pmids_')
+         )[split]
+         self.mentions = []
+         self.name_to_cuis = {}
+         self._post_init()
+
+     def _post_init(self):
+         for d in self.docs:
+             self.mentions.extend(d.to_contextual_mentions())
+         for m in self.mentions:
+             if m.mention not in self.name_to_cuis:
+                 self.name_to_cuis[m.mention] = set()
+             self.name_to_cuis[m.mention].update(m.cuis)
+
+     def __getitem__(self, index: int) -> ContextualMention:
+         return self.mentions[index]
+
+     def __len__(self) -> int:
+         return len(self.mentions)
+
+
+ class PreprocessedDataset(torch.utils.data.Dataset):
+
+     def __init__(self, dataset_path: str) -> None:
+         super().__init__()
+         self.file = dataset_path
+         self.data = []
+         self.load_data()
+
+     def load_data(self) -> None:
+         with open(self.file, encoding='utf-8') as f:
+             logger.info("Reading file %s", self.file)
+             for ln in f:
+                 if ln.strip():
+                     self.data.append(json.loads(ln))
+         logger.info("Loaded data size: %d", len(self.data))
+
+     def __getitem__(self, index: int) -> ContextualMention:
+         d = self.data[index]
+         return ContextualMention(
+             ctx_l=d['context_left'],
+             ctx_r=d['context_right'],
+             mention=d['mention'],
+             cuis=d['cuis'],
+         )
+
+     def __len__(self) -> int:
+         return len(self.data)
+
+
+ def generate_vectors(
+     encoder: torch.nn.Module,
+     tokenizer: PreTrainedTokenizer,
+     dataset: torch.utils.data.Dataset,
+     batch_size: int,
+     max_length: int,
+     is_prototype: bool = False,
+ ):
+     n = len(dataset)
+     total = 0
+     results = []
+     start_time = time.time()
+     logger.info("Start encoding...")
+     for i, batch_start in enumerate(range(0, n, batch_size)):
+         batch = [dataset[i] for i in range(batch_start, min(n, batch_start + batch_size))]
+         batch_token_tensors = [m.to_tensor(tokenizer, max_length) for m in batch]
+
+         ids_batch = torch.stack(batch_token_tensors, dim=0).cuda()
+         seg_batch = torch.zeros_like(ids_batch)
+         attn_mask = (ids_batch != tokenizer.pad_token_id)
+
+         with torch.inference_mode():
+             out = encoder(
+                 input_ids=ids_batch,
+                 token_type_ids=seg_batch,
+                 attention_mask=attn_mask
+             )
+             out = out[0][:, 0, :]
+         out = out.cpu()
+
+         num_mentions = out.size(0)
+         total += num_mentions
+
+         if is_prototype:
+             meta_batch = [{'cuis': m.cuis} for m in batch]
+             assert len(meta_batch) == num_mentions
+             results.extend([(meta_batch[i], out[i].view(-1).numpy()) for i in range(num_mentions)])
+         else:
+             results.extend(out.cpu().split(1, dim=0))
+
+         if (i + 1) % 10 == 0:
+             eta = (n - total) * (time.time() - start_time) / 60 / total
+             logger.info(f"Batch={i + 1}, Encoded mentions={total}, ETA={eta:.1f}m")
+
+     assert len(results) == n
+     logger.info(f"Total encoded mentions={n}")
+     if not is_prototype:
+         results = torch.cat(results, dim=0)
+
+     return results
+
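`ContextualMention.to_tensor` is the piece most likely to be reused outside these scripts: it wraps the mention in `<ENT>`/`</ENT>` markers (which the released KRISSBERT tokenizer is expected to contain, since the code above looks them up directly), splits the remaining token budget between left and right context, and pads to `max_length`. A minimal sketch, assuming the Hugging Face tokenizer resolves those marker tokens and using an invented clinical sentence:

```python
from transformers import AutoTokenizer
from utils import ContextualMention

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL", use_fast=True
)
cm = ContextualMention(
    mention="sepsis",
    cuis=["C0243026"],  # illustrative CUI
    ctx_l="The patient was admitted with suspected",
    ctx_r="and started on broad-spectrum antibiotics .",
)
token_ids = cm.to_tensor(tokenizer, max_length=64)
print(token_ids.shape)  # torch.Size([64]): [CLS] left ctx <ENT> mention </ENT> right ctx [SEP] + padding
```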