#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
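
# Example invocation (a sketch: the filename and the corpus/output paths are
# hypothetical, while the encoder checkpoint is a real Hugging Face model).
# The three sub-commands -- input, output, encoder -- are parsed independently
# by parse_args() below:
#
#   python encode.py \
#     input   --corpus /path/to/corpus-jsonl-dir --fields text \
#     output  --embeddings /path/to/output-dir --to-faiss \
#     encoder --encoder castorini/tct_colbert-v2-hnp-msmarco --batch-size 64 --fp16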

import argparse
import sys

from pyserini.encode import JsonlRepresentationWriter, FaissRepresentationWriter, JsonlCollectionIterator
from pyserini.encode import DprDocumentEncoder, TctColBertDocumentEncoder, AnceDocumentEncoder, AutoDocumentEncoder
from pyserini.encode import UniCoilDocumentEncoder

encoder_class_map = {
    "dpr": DprDocumentEncoder,
    "tct_colbert": TctColBertDocumentEncoder,
    "ance": AnceDocumentEncoder,
    "sentence-transformers": AutoDocumentEncoder,
    "unicoil": UniCoilDocumentEncoder,
    "auto": AutoDocumentEncoder,
}
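
# The keys above double as substring keywords: init_encoder() scans the
# lower-cased encoder name for each key, so e.g. the (real) checkpoint
# "facebook/dpr-ctx_encoder-multiset-base" is routed to DprDocumentEncoder.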

def init_encoder(encoder, encoder_class, device):
    _encoder_class = encoder_class

    # determine the encoder class
    if encoder_class is not None:
        encoder_class = encoder_class_map[encoder_class]
    else:
        # if any class keyword matches the given encoder name,
        # use that encoder class
        for class_keyword in encoder_class_map:
            if class_keyword in encoder.lower():
                encoder_class = encoder_class_map[class_keyword]
                break

        # if none of the class keywords matched,
        # fall back to AutoDocumentEncoder
        if encoder_class is None:
            encoder_class = AutoDocumentEncoder

    # prepare arguments for the encoder class
    kwargs = dict(model_name=encoder, device=device)
    if (_encoder_class == "sentence-transformers") or ("sentence-transformers" in encoder):
        kwargs.update(dict(pooling='mean', l2_norm=True))
    if (_encoder_class == "contriever") or ("contriever" in encoder):
        kwargs.update(dict(pooling='mean', l2_norm=False))
    return encoder_class(**kwargs)
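
# A minimal sketch of what init_encoder() resolves to (both checkpoints are
# real Hugging Face models; any name containing a map keyword behaves the same):
#   init_encoder('castorini/ance-msmarco-passage', None, 'cuda:0')
#       -> AnceDocumentEncoder
#   init_encoder('sentence-transformers/all-MiniLM-L6-v2', None, 'cpu')
#       -> AutoDocumentEncoder with pooling='mean', l2_norm=True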

def parse_args(parser, commands):
    # divide argv by sub-commands
    split_argv = [[]]
    for c in sys.argv[1:]:
        if c in commands.choices:
            split_argv.append([c])
        else:
            split_argv[-1].append(c)

    # initialize the namespace with one attribute per sub-command
    args = argparse.Namespace()
    for c in commands.choices:
        setattr(args, c, None)

    # parse each sub-command into its own sub-namespace
    parser.parse_args(split_argv[0], namespace=args)  # arguments before any sub-command
    for argv in split_argv[1:]:  # one argv slice per sub-command
        n = argparse.Namespace()
        setattr(args, argv[0], n)
        parser.parse_args(argv, namespace=n)
    return args
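
# For example (values hypothetical), the command line
#   input --corpus docs output --embeddings out encoder --encoder my-model
# yields args.input.corpus == 'docs', args.output.embeddings == 'out', and
# args.encoder.encoder == 'my-model'; sub-commands that never appear on the
# command line stay None.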

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(title='sub-commands')

    input_parser = commands.add_parser('input')
    input_parser.add_argument('--corpus', type=str,
                              help='directory that contains the corpus files to encode, in jsonl format',
                              required=True)
    input_parser.add_argument('--fields', help='fields that each jsonl document contains (in order)',
                              nargs='+', default=['text'], required=False)
    input_parser.add_argument('--delimiter', help='delimiter for the fields', default='\n', required=False)
    input_parser.add_argument('--shard-id', type=int, help='shard id (0-based)', default=0, required=False)
    input_parser.add_argument('--shard-num', type=int, help='number of shards', default=1, required=False)

    output_parser = commands.add_parser('output')
    output_parser.add_argument('--embeddings', type=str, help='directory to store the encoded corpus', required=True)
    output_parser.add_argument('--to-faiss', action='store_true', default=False)
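
    # With --to-faiss, vectors go straight into a FAISS index on disk via
    # FaissRepresentationWriter; otherwise JsonlRepresentationWriter emits one
    # JSON object per document, with the embedding under 'vector'.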

    encoder_parser = commands.add_parser('encoder')
    encoder_parser.add_argument('--encoder', type=str, help='encoder name or path', required=True)
    encoder_parser.add_argument('--encoder-class', type=str, required=False, default=None,
                                choices=["dpr", "tct_colbert", "ance", "sentence-transformers", "unicoil", "auto"],
                                help='which document encoder class to use; if not set, inferred from --encoder')
    encoder_parser.add_argument('--fields', help='fields to encode', nargs='+', default=['text'], required=False)
    encoder_parser.add_argument('--batch-size', type=int, help='batch size', default=64, required=False)
    encoder_parser.add_argument('--max-length', type=int, help='max sequence length', default=256, required=False)
    encoder_parser.add_argument('--dimension', type=int, help='embedding dimension', default=768, required=False)
    encoder_parser.add_argument('--device', type=str, help='device: cpu or cuda [cuda:0, cuda:1, ...]',
                                default='cuda:0', required=False)
    encoder_parser.add_argument('--fp16', action='store_true', default=False)
    encoder_parser.add_argument('--add-sep', action='store_true', default=False)

    args = parse_args(parser, commands)
    # a '\n' typed on the command line arrives as the two literal characters
    # '\' and 'n', so convert it back into a real newline
    delimiter = args.input.delimiter.replace("\\n", "\n")

    encoder = init_encoder(args.encoder.encoder, args.encoder.encoder_class, device=args.encoder.device)
    if args.output.to_faiss:
        embedding_writer = FaissRepresentationWriter(args.output.embeddings, dimension=args.encoder.dimension)
    else:
        embedding_writer = JsonlRepresentationWriter(args.output.embeddings)

    collection_iterator = JsonlCollectionIterator(args.input.corpus, args.input.fields, delimiter)

    with embedding_writer:
        for batch_info in collection_iterator(args.encoder.batch_size, args.input.shard_id, args.input.shard_num):
            kwargs = {
                'texts': batch_info['text'],
                'titles': batch_info['title'] if 'title' in args.encoder.fields else None,
                'expands': batch_info['expand'] if 'expand' in args.encoder.fields else None,
                'fp16': args.encoder.fp16,
                'max_length': args.encoder.max_length,
                'add_sep': args.encoder.add_sep,
            }
            embeddings = encoder.encode(**kwargs)
            batch_info['vector'] = embeddings
            embedding_writer.write(batch_info, args.input.fields)
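
# Shards can be encoded independently (e.g., one process per GPU) and the
# resulting embedding directories combined downstream; a sketch, assuming a
# hypothetical two-shard split of the same corpus:
#   python encode.py input --corpus docs --shard-id 0 --shard-num 2 ... &
#   python encode.py input --corpus docs --shard-id 1 --shard-num 2 ...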