# Provenance: ONNX-Demo / pyserini/encode/__main__.py
# (uploaded by ArthurChen189 in commit 62977bb, "upload pyserini")
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import sys
from pyserini.encode import JsonlRepresentationWriter, FaissRepresentationWriter, JsonlCollectionIterator
from pyserini.encode import DprDocumentEncoder, TctColBertDocumentEncoder, AnceDocumentEncoder, AggretrieverDocumentEncoder, AutoDocumentEncoder
from pyserini.encode import UniCoilDocumentEncoder
# Maps an encoder-class keyword to the document encoder class used for it.
# Key order is significant: init_encoder walks the keys in insertion order and
# takes the first keyword that occurs in the (lower-cased) model name.
encoder_class_map = {
    'dpr': DprDocumentEncoder,
    'tct_colbert': TctColBertDocumentEncoder,
    'aggretriever': AggretrieverDocumentEncoder,
    'ance': AnceDocumentEncoder,
    'sentence-transformers': AutoDocumentEncoder,
    'unicoil': UniCoilDocumentEncoder,
    'auto': AutoDocumentEncoder,
}

# Pooling strategies the command line accepts for AutoDocumentEncoder.
ALLOWED_POOLING_OPTS = ['cls', 'mean']
def init_encoder(encoder, encoder_class, device):
    """Instantiate the document encoder for a model name and optional class key.

    Args:
        encoder: Model name or path; forwarded to the encoder class as
            ``model_name``.
        encoder_class: A key of ``encoder_class_map`` (or ``"contriever"``),
            or ``None`` to infer the class from ``encoder``.
        device: Device string forwarded to the encoder class as ``device``.

    Returns:
        An instance of the resolved encoder class.

    Raises:
        KeyError: If ``encoder_class`` is given but is not a known class key.
    """
    # Extra constructor arguments applied, in this order, when the keyword is
    # either the requested class or part of the model name. A later preset
    # overrides an earlier one if both match.
    presets = (
        ("sentence-transformers", dict(pooling='mean', l2_norm=True)),
        ("contriever", dict(pooling='mean', l2_norm=False)),
    )
    preset_keywords = [keyword for keyword, _ in presets]

    # Keyword matching against the model name is case-insensitive throughout,
    # so the class lookup and the presets agree on what counts as a match.
    encoder_name = encoder.lower()
    requested = encoder_class

    if requested is None:
        # Use the first class keyword found in the model name; fall back to
        # AutoDocumentEncoder when none matches.
        resolved = next(
            (cls for keyword, cls in encoder_class_map.items() if keyword in encoder_name),
            AutoDocumentEncoder,
        )
    elif requested in encoder_class_map:
        resolved = encoder_class_map[requested]
    elif requested in preset_keywords:
        # A preset keyword without its own map entry (i.e. "contriever"):
        # models recognised by name alone end up on AutoDocumentEncoder, so an
        # explicit request resolves to the same class instead of failing.
        resolved = AutoDocumentEncoder
    else:
        valid = sorted(set(encoder_class_map) | set(preset_keywords))
        raise KeyError(f"unknown encoder class {requested!r}; expected one of {valid}")

    # prepare arguments to encoder class
    kwargs = dict(model_name=encoder, device=device)
    for keyword, preset in presets:
        if requested == keyword or keyword in encoder_name:
            kwargs.update(preset)
    return resolved(**kwargs)
def parse_args(parser, commands, argv=None):
    """Parse a command line on which several sub-commands appear together.

    The argument list is cut into groups at every token that equals a
    sub-command name, and each group is parsed into its own namespace.

    Note that any token equal to a sub-command name starts a new group, even
    if it was meant as the value of an option. If a sub-command appears more
    than once, the namespace of its last occurrence is kept.

    Args:
        parser: The top-level ``argparse.ArgumentParser``.
        commands: The sub-parsers action returned by
            ``parser.add_subparsers(...)``; its ``choices`` supply the
            sub-command names.
        argv: Argument strings to parse. Defaults to ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with one attribute per sub-command: the
        namespace parsed for that sub-command, or ``None`` if it was absent.
    """
    if argv is None:
        argv = sys.argv[1:]

    # Group 0 collects whatever precedes the first sub-command; each later
    # group starts with the sub-command token itself.
    groups = [[]]
    for token in argv:
        if token in commands.choices:
            groups.append([token])
        else:
            groups[-1].append(token)

    # Every known sub-command gets an attribute up front; None means
    # "not supplied on the command line".
    args = argparse.Namespace()
    for name in commands.choices:
        setattr(args, name, None)

    leading, *per_command = groups
    parser.parse_args(leading, namespace=args)  # arguments without a command
    for group in per_command:
        sub_namespace = argparse.Namespace()
        setattr(args, group[0], sub_namespace)
        parser.parse_args(group, namespace=sub_namespace)
    return args
if __name__ == '__main__':
    # Command-line entry point: read a JSONL corpus, encode it batch by batch
    # and write the vectors either as JSONL or into a Faiss index.
    # The command line is made of three sub-commands (input / output /
    # encoder), all of which must be given.
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(title='sub-commands')

    input_parser = commands.add_parser('input')
    input_parser.add_argument('--corpus', type=str,
                              help='directory that contains corpus files to be encoded, in jsonl format.',
                              required=True)
    input_parser.add_argument('--fields', help='fields that contents in jsonl has (in order)',
                              nargs='+', default=['text'], required=False)
    input_parser.add_argument('--docid-field',
                              help='name of document id field name. If you have a custom id with a name other than "id", "_id" or "docid", then use this argument',
                              default=None, required=False)
    input_parser.add_argument('--delimiter', help='delimiter for the fields', default='\n', required=False)
    input_parser.add_argument('--shard-id', type=int, help='shard-id 0-based', default=0, required=False)
    input_parser.add_argument('--shard-num', type=int, help='number of shards', default=1, required=False)

    output_parser = commands.add_parser('output')
    output_parser.add_argument('--embeddings', type=str, help='directory to store encoded corpus', required=True)
    output_parser.add_argument('--to-faiss', action='store_true', default=False)

    encoder_parser = commands.add_parser('encoder')
    encoder_parser.add_argument('--encoder', type=str, help='encoder name or path', required=True)
    # Choices are taken from encoder_class_map so that every accepted value can
    # actually be resolved by init_encoder (a hand-written list drifts from it).
    encoder_parser.add_argument('--encoder-class', type=str, required=False, default=None,
                                choices=list(encoder_class_map),
                                help='which document encoder class to use. If omitted, it is inferred from --encoder')
    encoder_parser.add_argument('--fields', help='fields to encode', nargs='+', default=['text'], required=False)
    encoder_parser.add_argument('--batch-size', type=int, help='batch size', default=64, required=False)
    encoder_parser.add_argument('--max-length', type=int, help='max length', default=256, required=False)
    encoder_parser.add_argument('--dimension', type=int, help='dimension', default=768, required=False)
    encoder_parser.add_argument('--device', type=str, help='device cpu or cuda [cuda:0, cuda:1...]',
                                default='cuda:0', required=False)
    encoder_parser.add_argument('--fp16', action='store_true', default=False)
    encoder_parser.add_argument('--add-sep', action='store_true', default=False)
    # No default here: None means "not requested", which lets the pooling that
    # init_encoder configured (e.g. 'mean' for sentence-transformers models)
    # survive instead of being overwritten by an implicit 'cls'.
    encoder_parser.add_argument('--pooling', type=str, default=None,
                                help='for auto classes, allow the ability to dictate pooling strategy '
                                     '(cls or mean). If omitted, the pooling set up for the encoder is kept',
                                required=False)

    args = parse_args(parser, commands)

    # A sub-command that was left out is None on the namespace; report that
    # as a usage error rather than failing later on an attribute access.
    missing = [name for name in commands.choices if getattr(args, name) is None]
    if missing:
        parser.error(f"missing required sub-command(s): {', '.join(missing)}")

    # A delimiter typed as the two characters backslash + n on the command
    # line is turned into a real newline.
    delimiter = args.input.delimiter.replace("\\n", "\n")

    encoder = init_encoder(args.encoder.encoder, args.encoder.encoder_class, device=args.encoder.device)
    # Only an explicit --pooling overrides the encoder's pooling.
    # NOTE(review): without --pooling a plain AutoDocumentEncoder now keeps its
    # own constructor default, presumably 'cls' — confirm against the class.
    if type(encoder).__name__ == "AutoDocumentEncoder" and args.encoder.pooling is not None:
        if args.encoder.pooling not in ALLOWED_POOLING_OPTS:
            raise ValueError(f"Only allowed to use pooling types {ALLOWED_POOLING_OPTS}. You entered {args.encoder.pooling}")
        encoder.pooling = args.encoder.pooling

    if args.output.to_faiss:
        embedding_writer = FaissRepresentationWriter(args.output.embeddings, dimension=args.encoder.dimension)
    else:
        embedding_writer = JsonlRepresentationWriter(args.output.embeddings)
    collection_iterator = JsonlCollectionIterator(args.input.corpus, args.input.fields, args.input.docid_field, delimiter)

    encode_fields = args.encoder.fields
    with embedding_writer:
        for batch_info in collection_iterator(args.encoder.batch_size, args.input.shard_id, args.input.shard_num):
            # Titles / expansions are passed to the encoder only when the
            # corresponding field was requested via `encoder --fields`.
            kwargs = {
                'texts': batch_info['text'],
                'titles': batch_info['title'] if 'title' in encode_fields else None,
                'expands': batch_info['expand'] if 'expand' in encode_fields else None,
                'fp16': args.encoder.fp16,
                'max_length': args.encoder.max_length,
                'add_sep': args.encoder.add_sep,
            }
            embeddings = encoder.encode(**kwargs)
            batch_info['vector'] = embeddings
            embedding_writer.write(batch_info, args.input.fields)