"""Encode a JSONL corpus into document representations (e.g., dense vectors
for Faiss indexing), driven by three combinable sub-commands: input, output,
and encoder (see parse_args below)."""
import argparse
import sys

from pyserini.encode import JsonlRepresentationWriter, FaissRepresentationWriter, JsonlCollectionIterator
from pyserini.encode import DprDocumentEncoder, TctColBertDocumentEncoder, AnceDocumentEncoder, AggretrieverDocumentEncoder, AutoDocumentEncoder
from pyserini.encode import UniCoilDocumentEncoder


# Maps an --encoder-class keyword (also matched as a substring of the model
# name or path) to the document encoder class that handles it.
encoder_class_map = {
    "dpr": DprDocumentEncoder,
    "tct_colbert": TctColBertDocumentEncoder,
    "aggretriever": AggretrieverDocumentEncoder,
    "ance": AnceDocumentEncoder,
    "sentence-transformers": AutoDocumentEncoder,
    "unicoil": UniCoilDocumentEncoder,
    "auto": AutoDocumentEncoder,
}
ALLOWED_POOLING_OPTS = ["cls", "mean"]


def init_encoder(encoder, encoder_class, device):
    _encoder_class = encoder_class

    # an explicitly given --encoder-class takes precedence
    if encoder_class is not None:
        encoder_class = encoder_class_map[encoder_class]
    else:
        # otherwise, infer the class from keywords in the model name or path
        for class_keyword in encoder_class_map:
            if class_keyword in encoder.lower():
                encoder_class = encoder_class_map[class_keyword]
                break

        # if no keyword matched, fall back to AutoDocumentEncoder
        if encoder_class is None:
            encoder_class = AutoDocumentEncoder

    # sentence-transformers and contriever models need non-default pooling
    # and normalization settings
    kwargs = dict(model_name=encoder, device=device)
    if (_encoder_class == "sentence-transformers") or ("sentence-transformers" in encoder):
        kwargs.update(dict(pooling='mean', l2_norm=True))
    if (_encoder_class == "contriever") or ("contriever" in encoder):
        kwargs.update(dict(pooling='mean', l2_norm=False))
    return encoder_class(**kwargs)
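# For example (model names are illustrative): init_encoder('facebook/dpr-ctx_encoder-multiset-base', None, 'cpu')
# matches the "dpr" keyword and returns a DprDocumentEncoder, while a model name
# containing no known keyword falls back to AutoDocumentEncoder.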


def parse_args(parser, commands):
    # Divide argv by sub-command names so that several sub-commands can be
    # combined in a single invocation, each with its own options.
    split_argv = [[]]
    for c in sys.argv[1:]:
        if c in commands.choices:
            split_argv.append([c])
        else:
            split_argv[-1].append(c)

    # initialize the top-level namespace with one attribute per sub-command
    args = argparse.Namespace()
    for c in commands.choices:
        setattr(args, c, None)

    # parse each section into a sub-namespace attached to the top-level one
    parser.parse_args(split_argv[0], namespace=args)
    for argv in split_argv[1:]:
        n = argparse.Namespace()
        setattr(args, argv[0], n)
        parser.parse_args(argv, namespace=n)
    return args
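# For example, an argv like `input --corpus c output --embeddings e encoder --encoder m`
# yields args.input.corpus == 'c', args.output.embeddings == 'e', and args.encoder.encoder == 'm'.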


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(title='sub-commands')

    input_parser = commands.add_parser('input')
    input_parser.add_argument('--corpus', type=str,
                              help='directory of corpus files to be encoded, in JSONL format',
                              required=True)
    input_parser.add_argument('--fields', help='fields that the JSONL contents contain (in order)',
                              nargs='+', default=['text'], required=False)
    input_parser.add_argument('--docid-field',
                              help='name of the document id field; use this if your id field is named something other than "id", "_id", or "docid"',
                              default=None, required=False)
    input_parser.add_argument('--delimiter', help='delimiter for the fields', default='\n', required=False)
    input_parser.add_argument('--shard-id', type=int, help='shard id (0-based)', default=0, required=False)
    input_parser.add_argument('--shard-num', type=int, help='number of shards', default=1, required=False)

    output_parser = commands.add_parser('output')
    output_parser.add_argument('--embeddings', type=str, help='directory to store the encoded corpus', required=True)
    output_parser.add_argument('--to-faiss', action='store_true', default=False)

    encoder_parser = commands.add_parser('encoder')
    encoder_parser.add_argument('--encoder', type=str, help='encoder name or path', required=True)
    encoder_parser.add_argument('--encoder-class', type=str, required=False, default=None,
                                choices=["dpr", "tct_colbert", "aggretriever", "ance", "sentence-transformers", "unicoil", "auto"],
                                help='which document encoder class to use; if not given, the class is inferred from --encoder')
    encoder_parser.add_argument('--fields', help='fields to encode', nargs='+', default=['text'], required=False)
    encoder_parser.add_argument('--batch-size', type=int, help='batch size', default=64, required=False)
    encoder_parser.add_argument('--max-length', type=int, help='max sequence length', default=256, required=False)
    encoder_parser.add_argument('--dimension', type=int, help='embedding dimension', default=768, required=False)
    encoder_parser.add_argument('--device', type=str, help='device: cpu or cuda [cuda:0, cuda:1, ...]',
                                default='cuda:0', required=False)
    encoder_parser.add_argument('--fp16', action='store_true', default=False)
    encoder_parser.add_argument('--add-sep', action='store_true', default=False)
    encoder_parser.add_argument('--pooling', type=str, default='cls',
                                help='pooling strategy for auto encoder classes', required=False)

    args = parse_args(parser, commands)
    # allow a literal "\n" typed on the command line to mean a newline delimiter
    delimiter = args.input.delimiter.replace("\\n", "\n")

    encoder = init_encoder(args.encoder.encoder, args.encoder.encoder_class, device=args.encoder.device)
    if type(encoder).__name__ == "AutoDocumentEncoder":
        if args.encoder.pooling in ALLOWED_POOLING_OPTS:
            encoder.pooling = args.encoder.pooling
        else:
            raise ValueError(f"Only allowed to use pooling types {ALLOWED_POOLING_OPTS}. You entered {args.encoder.pooling}")

    if args.output.to_faiss:
        embedding_writer = FaissRepresentationWriter(args.output.embeddings, dimension=args.encoder.dimension)
    else:
        embedding_writer = JsonlRepresentationWriter(args.output.embeddings)
    collection_iterator = JsonlCollectionIterator(args.input.corpus, args.input.fields, args.input.docid_field, delimiter)

    # encode the corpus batch by batch, writing embeddings as we go
    with embedding_writer:
        for batch_info in collection_iterator(args.encoder.batch_size, args.input.shard_id, args.input.shard_num):
            kwargs = {
                'texts': batch_info['text'],
                'titles': batch_info['title'] if 'title' in args.encoder.fields else None,
                'expands': batch_info['expand'] if 'expand' in args.encoder.fields else None,
                'fp16': args.encoder.fp16,
                'max_length': args.encoder.max_length,
                'add_sep': args.encoder.add_sep,
            }
            embeddings = encoder.encode(**kwargs)
            batch_info['vector'] = embeddings
            embedding_writer.write(batch_info, args.input.fields)
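# Example invocation (paths and the model name are illustrative; assumes this
# module is runnable as pyserini.encode):
#
#   python -m pyserini.encode \
#     input   --corpus path/to/corpus.jsonl --fields text \
#     output  --embeddings path/to/embeddings --to-faiss \
#     encoder --encoder castorini/tct_colbert-v2-hnp-msmarco --fields text --batch-size 32 --fp16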