# NOTE: page-scrape residue, not part of the program: "Spaces: Runtime error"
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json | |
import argparse | |
from kilt import retrieval | |
from kilt import kilt_utils as utils | |
from kilt.retrievers import DrQA_tfidf | |
def execute(
    logger, test_config_json, retriever, log_directory, model_name, output_folder, topk
):
    """Run the retrieval evaluation by delegating to ``retrieval.run``.

    Args:
        logger: Logger forwarded to ``retrieval.run``.
        test_config_json: Parsed task configuration forwarded to ``retrieval.run``.
        retriever: Retriever instance forwarded to ``retrieval.run``.
        log_directory: Accepted for interface compatibility; not used here.
        model_name: Retriever/model name forwarded to ``retrieval.run``.
        output_folder: Forwarded as the ``output_folder`` keyword argument.
        topk: Forwarded as the ``topk`` keyword argument.
    """
    # Keyword options are gathered first; the positional order expected by
    # ``retrieval.run`` is left untouched.
    run_options = {"output_folder": output_folder, "topk": topk}
    retrieval.run(test_config_json, retriever, model_name, logger, **run_options)
def main(args):
    """Load the task config, set up logging, build the retriever and run it.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``config``, ``name``, ``index_dir``, ``threads``, ``output_dir``
            and ``topk``.
    """
    # load configs
    # JSON text is UTF-8 by specification, so decode it explicitly instead of
    # relying on the platform's locale encoding.
    with open(args.config, "r", encoding="utf-8") as fin:
        test_config_json = json.load(fin)

    # create a new directory to log and store results
    log_directory = utils.create_logdir_with_timestamp("logs")
    # No pre-existing logger is supplied, so None is passed in that slot.
    logger = utils.init_logging(log_directory, args.name, None)

    logger.info("loading {} ...".format(args.name))
    retriever = DrQA_tfidf.DrQA(args.name, args.index_dir, args.threads)

    execute(
        logger,
        test_config_json,
        retriever,
        log_directory,
        args.name,
        f"{args.output_dir}/{args.name}",
        args.topk,
    )
if __name__ == "__main__":
    # Command-line entry point: declare the options in one table, register
    # them in order, then hand the parsed namespace to main().
    parser = argparse.ArgumentParser(description="Execute retrieval on KILT")

    option_specs = (
        ("--config", {
            "default": "kilt/configs/dev_data.json",
            "help": "Config json containing which tasks to run",
        }),
        ("--index_dir", {
            "default": "models/kilt_db_simple.npz",
            "help": "Path to the DRQA index directory",
        }),
        ("--output_dir", {
            "default": "outputs",
            "help": "Output directory",
        }),
        ("--topk", {
            "type": int,
            "default": 100,
            "help": "Return the top k elements for a query",
        }),
        ("--name", {
            "default": "drqa",
            "help": "Name of the retriever",
        }),
        ("--threads", {
            "type": int,
            "default": 8,
            "help": "Num of threads",
        }),
    )
    for flag, spec in option_specs:
        # Every option is optional and falls back to its default.
        parser.add_argument(flag, required=False, **spec)

    args = parser.parse_args()
    main(args)