import json
import os

import faiss
import torch
import numpy as np
from tqdm import tqdm

|
class DocumentEncoder:
    """Base class for document-side encoders; subclasses implement `encode`."""

    def encode(self, texts, **kwargs):
        pass

    @staticmethod
    def _mean_pooling(last_hidden_state, attention_mask):
        # Mean-pool the token embeddings, using the attention mask so that
        # padding tokens do not contribute to the average.
        token_embeddings = last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

|
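# A minimal sketch of a concrete DocumentEncoder built on Hugging Face Transformers.
# It is illustrative only; the class and model names below are hypothetical and not
# part of this module:
#
#   from transformers import AutoModel, AutoTokenizer
#
#   class MyDocumentEncoder(DocumentEncoder):
#       def __init__(self, model_name='bert-base-uncased', device='cpu'):
#           self.tokenizer = AutoTokenizer.from_pretrained(model_name)
#           self.model = AutoModel.from_pretrained(model_name).to(device).eval()
#           self.device = device
#
#       def encode(self, texts, max_length=256, **kwargs):
#           inputs = self.tokenizer(texts, max_length=max_length, padding=True,
#                                   truncation=True, return_tensors='pt').to(self.device)
#           with torch.no_grad():
#               outputs = self.model(**inputs)
#           pooled = self._mean_pooling(outputs.last_hidden_state, inputs['attention_mask'])
#           return pooled.cpu().numpy()
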
class QueryEncoder:
    """Base class for query-side encoders; subclasses implement `encode`."""

    def encode(self, text, **kwargs):
        pass

|
class PcaEncoder:
    """Wraps another encoder and reduces its output dimensionality with a FAISS PCA transform."""

    def __init__(self, encoder, pca_model_path):
        self.encoder = encoder
        self.pca_mat = faiss.read_VectorTransform(pca_model_path)

    def encode(self, text, **kwargs):
        embeddings = self.encoder.encode(text, **kwargs)
        if isinstance(text, str):
            # A single input yields a single vector: wrap it into a batch for the
            # PCA transform, then unwrap the result.
            embeddings = self.pca_mat.apply_py(np.array([embeddings]))
            embeddings = embeddings[0]
        else:
            embeddings = self.pca_mat.apply_py(embeddings)
        return embeddings
|
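# A minimal usage sketch for PcaEncoder, assuming a concrete encoder subclass and a
# PCA transform previously saved with faiss.write_VectorTransform (both hypothetical):
#
#   encoder = PcaEncoder(MyDocumentEncoder(), 'pca.transform')
#   reduced = encoder.encode(['first passage', 'second passage'])
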
|
class JsonlCollectionIterator:
    """Iterates over a JSONL collection (a single file or a directory of files) in batches."""

    def __init__(self, collection_path: str, fields=None, docid_field=None, delimiter="\n"):
        self.fields = fields if fields else ['text']
        self.docid_field = docid_field
        self.delimiter = delimiter
        self.all_info = self._load(collection_path)
        self.size = len(self.all_info['id'])
        self.batch_size = 1
        self.shard_id = 0
        self.shard_num = 1
|
    def __call__(self, batch_size=1, shard_id=0, shard_num=1):
        self.batch_size = batch_size
        self.shard_id = shard_id
        self.shard_num = shard_num
        return self
|
    def __iter__(self):
        total_len = self.size
        shard_size = total_len // self.shard_num
        start_idx = self.shard_id * shard_size
        end_idx = min(start_idx + shard_size, total_len)
        if self.shard_id == self.shard_num - 1:
            # The last shard also picks up any remainder left by the integer division.
            end_idx = total_len
        for idx in tqdm(range(start_idx, end_idx, self.batch_size)):
            # Build a fresh dict per batch so that callers can safely keep references
            # to previously yielded batches.
            to_yield = {}
            for key in self.all_info:
                to_yield[key] = self.all_info[key][idx: min(idx + self.batch_size, end_idx)]
            yield to_yield
|
    def _parse_fields_from_info(self, info):
        """
        :param info: dict containing all fields specified in self.fields, either under
            the key of each field name or under the key 'contents'. If under 'contents',
            this function parses the contents into the individual fields based on
            self.delimiter.
        :return: List, where each element corresponds to the value of one of self.fields.
        """
        n_fields = len(self.fields)

        # If every field is present under its own key, read the values directly.
        if all([field in info for field in self.fields]):
            return [info[field].strip() for field in self.fields]

        assert "contents" in info, f"contents not found in info: {info}"
        contents = info['contents']
        # If there is exactly one more delimiter than needed to separate n_fields values
        # and the contents end with that delimiter, drop the trailing delimiter so that
        # splitting yields exactly n_fields values.
        if contents.count(self.delimiter) == n_fields:
            if contents.endswith(self.delimiter):
                contents = contents[:-len(self.delimiter)]
        return [field.strip(" ") for field in contents.split(self.delimiter)]
|
    def _load(self, collection_path):
        filenames = []
        if os.path.isfile(collection_path):
            filenames.append(collection_path)
        else:
            for filename in os.listdir(collection_path):
                filenames.append(os.path.join(collection_path, filename))
        all_info = {field: [] for field in self.fields}
        all_info['id'] = []
        for filename in filenames:
            with open(filename) as f:
                for line_i, line in tqdm(enumerate(f)):
                    info = json.loads(line)
                    if self.docid_field:
                        _id = info.get(self.docid_field, None)
                    else:
                        _id = info.get('id', info.get('_id', info.get('docid', None)))
                    if _id is None:
                        raise ValueError(
                            f"Cannot find `{self.docid_field if self.docid_field else 'id` or `_id` or `docid'}` "
                            f"in {filename}."
                        )
                    all_info['id'].append(str(_id))
                    fields_info = self._parse_fields_from_info(info)
                    if len(fields_info) != len(self.fields):
                        raise ValueError(
                            f"{len(fields_info)} fields are found at Line#{line_i} in file {filename}. "
                            f"{len(self.fields)} fields expected. "
                            f"Line content: {info['contents']}"
                        )
                    for i in range(len(fields_info)):
                        all_info[self.fields[i]].append(fields_info[i])
        return all_info
|
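# A minimal usage sketch for the iterator; the collection path and field names below
# are hypothetical:
#
#   collection = JsonlCollectionIterator('collection.jsonl', fields=['title', 'text'])
#   for batch in collection(batch_size=64, shard_id=0, shard_num=1):
#       ids, titles, texts = batch['id'], batch['title'], batch['text']
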
|
class RepresentationWriter:
    """Base context manager for persisting encoded representations; subclasses implement the hooks."""

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def write(self, batch_info, fields=None):
        pass

|
|
class JsonlRepresentationWriter(RepresentationWriter):
    """Writes each document id, its contents, and its vector as one JSON object per line."""

    def __init__(self, dir_path):
        self.dir_path = dir_path
        self.filename = 'embeddings.jsonl'
        self.file = None

    def __enter__(self):
        if not os.path.exists(self.dir_path):
            os.makedirs(self.dir_path)
        self.file = open(os.path.join(self.dir_path, self.filename), 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def write(self, batch_info, fields=None):
        if fields is None:
            # Fall back to the single 'text' field, mirroring JsonlCollectionIterator's default.
            fields = ['text']
        for i in range(len(batch_info['id'])):
            contents = "\n".join([batch_info[key][i] for key in fields])
            vector = batch_info['vector'][i]
            vector = vector.tolist() if isinstance(vector, np.ndarray) else vector
            self.file.write(json.dumps({'id': batch_info['id'][i],
                                        'contents': contents,
                                        'vector': vector}) + '\n')
|
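# Each line written by JsonlRepresentationWriter is one JSON object of the form
# (values illustrative):
#   {"id": "doc1", "contents": "some passage text", "vector": [0.12, -0.03, ...]}
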
|
class FaissRepresentationWriter(RepresentationWriter):
    """Writes vectors into a flat inner-product FAISS index, with document ids in a sidecar file."""

    def __init__(self, dir_path, dimension=768):
        self.dir_path = dir_path
        self.index_name = 'index'
        self.id_file_name = 'docid'
        self.dimension = dimension
        self.index = faiss.IndexFlatIP(self.dimension)
        self.id_file = None

    def __enter__(self):
        if not os.path.exists(self.dir_path):
            os.makedirs(self.dir_path)
        self.id_file = open(os.path.join(self.dir_path, self.id_file_name), 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.id_file.close()
        faiss.write_index(self.index, os.path.join(self.dir_path, self.index_name))

    def write(self, batch_info, fields=None):
        for id_ in batch_info['id']:
            self.id_file.write(f'{id_}\n')
        # FAISS expects contiguous float32 input.
        self.index.add(np.ascontiguousarray(batch_info['vector'], dtype='float32'))
|
|
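# End-to-end sketch: encode a JSONL collection and write a FAISS index. The encoder
# class (the hypothetical MyDocumentEncoder sketched above), paths, and dimension are
# placeholders:
#
#   encoder = PcaEncoder(MyDocumentEncoder(), 'pca.transform')
#   collection = JsonlCollectionIterator('collection.jsonl', fields=['text'])
#   writer = FaissRepresentationWriter('indexes/my-index', dimension=256)
#   with writer:
#       for batch in collection(batch_size=64):
#           batch['vector'] = encoder.encode(batch['text'])
#           writer.write(batch, fields=['text'])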