import hashlib
import logging
import os
import re
import shutil
import tarfile
from urllib.error import HTTPError, URLError
from urllib.request import urlretrieve

import pandas as pd
from tqdm import tqdm

from pyserini.encoded_query_info import QUERY_INFO
from pyserini.encoded_corpus_info import CORPUS_INFO
from pyserini.evaluate_script_info import EVALUATION_INFO
from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO

logger = logging.getLogger(__name__)


class TqdmUpTo(tqdm):
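    """Progress bar that can be passed to urlretrieve as a reporthook via update_to."""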
    def update_to(self, b=1, bsize=1, tsize=None):
        """
        b : int, optional
            Number of blocks transferred so far [default: 1].
        bsize : int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize : int, optional
            Total size (in tqdm units). If [default: None], remains unchanged.
        """
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)


def compute_md5(file, block_size=2**20):
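    """Compute the MD5 checksum of a file, reading it in blocks of `block_size` bytes."""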
    m = hashlib.md5()
    with open(file, 'rb') as f:
        while True:
            buf = f.read(block_size)
            if not buf:
                break
            m.update(buf)
    return m.hexdigest()


def download_url(url, save_dir, local_filename=None, md5=None, force=False, verbose=True):
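    """Download `url` into `save_dir` and return the local path.

    If `local_filename` is not given, the filename is derived from the URL. An existing file is
    reused unless `force=True`. If `md5` is given, the download is verified against that checksum.
    """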
    if not local_filename:
        filename = url.split('/')[-1]
        # Strip Dropbox-style '?dl=1' suffixes from the derived filename.
        filename = re.sub(r'\?dl=1$', '', filename)
    else:
        filename = local_filename

    destination_path = os.path.join(save_dir, filename)

    if verbose:
        print(f'Downloading {url} to {destination_path}...')

    # If the file already exists, reuse it unless force=True.
    if os.path.exists(destination_path):
        if verbose:
            print(f'{destination_path} already exists!')
        if not force:
            if verbose:
                print('Skipping download.')
            return destination_path
        if verbose:
            print(f'force=True, removing {destination_path}; fetching fresh copy...')
        os.remove(destination_path)

    with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=filename) as t:
        urlretrieve(url, filename=destination_path, reporthook=t.update_to)

    if md5:
        md5_computed = compute_md5(destination_path)
        assert md5_computed == md5, f'{destination_path} does not match checksum! Expecting {md5} got {md5_computed}.'

    return destination_path


def get_cache_home():
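    """Return the Pyserini cache directory, honoring the PYSERINI_CACHE environment variable."""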
    custom_dir = os.environ.get("PYSERINI_CACHE")
    if custom_dir is not None and custom_dir != '':
        return custom_dir
    return os.path.expanduser(os.path.join(f'~{os.path.sep}.cache', "pyserini"))


def download_and_unpack_index(url, index_directory='indexes', local_filename=False,
                              force=False, verbose=True, prebuilt=False, md5=None):
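    """Download a tarball from `url`, unpack it into `index_directory`, and return the index path.

    For prebuilt artifacts (`prebuilt=True`), the index is stored under the Pyserini cache
    directory and the MD5 checksum is appended to the directory name.
    """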
    if not local_filename:
        index_name = url.split('/')[-1]
    else:
        index_name = local_filename

    # Strip the '.tar.gz' extension (and anything after it) to get the index name.
    index_name = re.sub(r'\.tar\.gz.*$', '', index_name)

    if prebuilt:
        index_directory = os.path.join(get_cache_home(), index_directory)
        index_path = os.path.join(index_directory, f'{index_name}.{md5}')

        if not os.path.exists(index_directory):
            os.makedirs(index_directory)

        local_tarball = os.path.join(index_directory, f'{index_name}.tar.gz')
        # Remove any leftover tarball from a previous (possibly interrupted) download.
        if os.path.exists(local_tarball):
            os.remove(local_tarball)
    else:
        local_tarball = os.path.join(index_directory, f'{index_name}.tar.gz')
        index_path = os.path.join(index_directory, f'{index_name}')

    if os.path.exists(index_path):
        if not force:
            if verbose:
                print(f'{index_path} already exists, skipping download.')
            return index_path
        if verbose:
            print(f'{index_path} already exists, but force=True, removing {index_path} and fetching fresh copy...')
        shutil.rmtree(index_path)

    print(f'Downloading index at {url}...')
    download_url(url, index_directory, local_filename=local_filename, verbose=False, md5=md5)

    if verbose:
        print(f'Extracting {local_tarball} into {index_path}...')
    try:
        tarball = tarfile.open(local_tarball)
    except Exception:
        # The downloaded file may not carry a .tar.gz suffix; retry with the bare index name.
        local_tarball = os.path.join(index_directory, f'{index_name}')
        tarball = tarfile.open(local_tarball)
    dirs_in_tarball = [member.name for member in tarball if member.isdir()]
    assert len(dirs_in_tarball) == 1, \
        f"Detected multiple members ({', '.join(dirs_in_tarball)}) under the tarball {local_tarball}."
    tarball.extractall(index_directory)
    tarball.close()
    os.remove(local_tarball)

    if prebuilt:
        dir_in_tarball = dirs_in_tarball[0]
        if dir_in_tarball != index_name:
            logger.info(f"Renaming {index_directory}/{dir_in_tarball} into {index_directory}/{index_name}.")
            index_name = dir_in_tarball
        os.rename(os.path.join(index_directory, f'{index_name}'), index_path)

    return index_path


def check_downloaded(index_name):
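    """Return True if the prebuilt index `index_name` has already been downloaded to the cache."""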
    if index_name in TF_INDEX_INFO:
        target_index = TF_INDEX_INFO[index_name]
    elif index_name in IMPACT_INDEX_INFO:
        target_index = IMPACT_INDEX_INFO[index_name]
    else:
        target_index = FAISS_INDEX_INFO[index_name]
    index_url = target_index['urls'][0]
    index_md5 = target_index['md5']
    index_name = index_url.split('/')[-1]
    index_name = re.sub(r'\.tar\.gz.*$', '', index_name)
    index_directory = os.path.join(get_cache_home(), 'indexes')
    index_path = os.path.join(index_directory, f'{index_name}.{index_md5}')

    return os.path.exists(index_path)


def get_sparse_indexes_info():
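    """Print a table of available sparse (TF and impact) prebuilt indexes and whether each is downloaded."""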
    df = pd.DataFrame.from_dict({**TF_INDEX_INFO, **IMPACT_INDEX_INFO})
    for index in df.keys():
        df[index]['downloaded'] = check_downloaded(index)

    with pd.option_context('display.max_rows', None, 'display.max_columns', None,
                           'display.max_colwidth', None, 'display.colheader_justify', 'left'):
        print(df)


def get_impact_indexes_info():
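    """Print a table of available impact prebuilt indexes and whether each is downloaded."""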
    df = pd.DataFrame.from_dict(IMPACT_INDEX_INFO)
    for index in df.keys():
        df[index]['downloaded'] = check_downloaded(index)

    with pd.option_context('display.max_rows', None, 'display.max_columns', None,
                           'display.max_colwidth', None, 'display.colheader_justify', 'left'):
        print(df)


def get_dense_indexes_info():
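    """Print a table of available dense (FAISS) prebuilt indexes and whether each is downloaded."""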
    df = pd.DataFrame.from_dict(FAISS_INDEX_INFO)
    for index in df.keys():
        df[index]['downloaded'] = check_downloaded(index)

    with pd.option_context('display.max_rows', None, 'display.max_columns', None,
                           'display.max_colwidth', None, 'display.colheader_justify', 'left'):
        print(df)


def download_prebuilt_index(index_name, force=False, verbose=True, mirror=None):
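    """Download the prebuilt index `index_name` and return the path to the unpacked index.

    Each known URL is tried in turn until one succeeds. Example usage (illustrative; the name
    must be a key in TF_INDEX_INFO, IMPACT_INDEX_INFO, or FAISS_INDEX_INFO):

        index_dir = download_prebuilt_index('msmarco-v1-passage')
    """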
    if index_name not in TF_INDEX_INFO and index_name not in FAISS_INDEX_INFO and index_name not in IMPACT_INDEX_INFO:
        raise ValueError(f'Unrecognized index name {index_name}')
    if index_name in TF_INDEX_INFO:
        target_index = TF_INDEX_INFO[index_name]
    elif index_name in IMPACT_INDEX_INFO:
        target_index = IMPACT_INDEX_INFO[index_name]
    else:
        target_index = FAISS_INDEX_INFO[index_name]
    index_md5 = target_index['md5']
    # Try each known URL in turn; fall through to the next on network errors.
    for url in target_index['urls']:
        local_filename = target_index['filename'] if 'filename' in target_index else None
        try:
            return download_and_unpack_index(url, local_filename=local_filename,
                                             prebuilt=True, md5=index_md5, verbose=verbose)
        except (HTTPError, URLError):
            print(f'Unable to download pre-built index at {url}, trying next URL...')
    raise ValueError('Unable to download pre-built index at any known URLs.')


def download_encoded_queries(query_name, force=False, verbose=True, mirror=None):
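    """Download the encoded queries `query_name` into the cache and return the local path."""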
    if query_name not in QUERY_INFO:
        raise ValueError(f'Unrecognized query name {query_name}')
    query_md5 = QUERY_INFO[query_name]['md5']
    for url in QUERY_INFO[query_name]['urls']:
        try:
            return download_and_unpack_index(url, index_directory='queries', prebuilt=True, md5=query_md5)
        except (HTTPError, URLError):
            print(f'Unable to download encoded query at {url}, trying next URL...')
    raise ValueError('Unable to download encoded query at any known URLs.')


def download_encoded_corpus(corpus_name, force=False, verbose=True, mirror=None):
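    """Download the encoded corpus `corpus_name` into the cache and return the local path."""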
    if corpus_name not in CORPUS_INFO:
        raise ValueError(f'Unrecognized corpus name {corpus_name}')
    corpus_md5 = CORPUS_INFO[corpus_name]['md5']
    for url in CORPUS_INFO[corpus_name]['urls']:
        local_filename = CORPUS_INFO[corpus_name]['filename'] if 'filename' in CORPUS_INFO[corpus_name] else None
        try:
            return download_and_unpack_index(url, local_filename=local_filename, index_directory='corpus',
                                             prebuilt=True, md5=corpus_md5)
        except (HTTPError, URLError):
            print(f'Unable to download encoded corpus at {url}, trying next URL...')
    raise ValueError('Unable to download encoded corpus at any known URLs.')


def download_evaluation_script(evaluation_name, force=False, verbose=True, mirror=None):
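    """Download the evaluation script `evaluation_name` into the cache and return the local path."""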
    if evaluation_name not in EVALUATION_INFO:
        raise ValueError(f'Unrecognized evaluation name {evaluation_name}')
    for url in EVALUATION_INFO[evaluation_name]['urls']:
        try:
            save_dir = os.path.join(get_cache_home(), 'eval')
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            return download_url(url, save_dir=save_dir)
        except HTTPError:
            print(f'Unable to download evaluation script at {url}, trying next URL...')
    raise ValueError('Unable to download evaluation script at any known URLs.')


def get_sparse_index(index_name):
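    """Return the name of the sparse index holding the texts associated with the dense (FAISS) index `index_name`."""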
    if index_name not in FAISS_INDEX_INFO:
        raise ValueError(f'Unrecognized index name {index_name}')
    return FAISS_INDEX_INFO[index_name]["texts"]