# NOTE: page-capture residue — this file appears to come from a hosted "Spaces" page whose status banner read "Build error".
from datasets import load_dataset | |
from transformers import DPRContextEncoderTokenizer, DPRContextEncoder | |
from general_utils import embed_passages, embed_passages_haystack | |
import faiss | |
import argparse | |
import os | |
from haystack.nodes import DensePassageRetriever | |
from haystack.document_stores import InMemoryDocumentStore | |
# Export OMP_NUM_THREADS=8 into this process's environment.
# NOTE(review): this runs after the imports above (faiss, transformers,
# haystack); some OpenMP runtimes read the variable only when they are first
# loaded — confirm the setting actually takes effect at this position.
os.environ.update({"OMP_NUM_THREADS": "8"})
def create_faiss_index(args):
    """Embed the Spanish biomedical corpus and save a Faiss index over it.

    Steps performed:

    1. Build a Haystack ``DensePassageRetriever`` backed by an in-memory
       document store.
    2. Load the ``train`` split of ``IIC/spanish_biomedical_crawled_corpus``
       and keep only rows whose ``text`` is longer than 200 characters.
    3. Run ``embed_passages_haystack`` over the rows in batches of 8192
       (the helper is expected to add an ``embeddings`` column — it is
       defined in ``general_utils``, not in this file).
    4. Build a Faiss index on the ``embeddings`` column and write it to
       ``args.index_file_name``.

    Args:
        args: Parsed command-line namespace. Only ``args.index_file_name``
            (output path of the Faiss index file) is read by this function.

    Returns:
        None. The result is the index file written to disk.
    """
    # Rows whose text has this many characters or fewer are dropped.
    min_chars = 200

    # NOTE(review): query and passage slots both use the *question* encoder
    # checkpoint. DPR setups usually embed passages with a separate
    # passage/context encoder; confirm this is intentional before changing
    # it, because whoever queries the saved index must use a matching encoder.
    retriever = DensePassageRetriever(
        document_store=InMemoryDocumentStore(),
        query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
        passage_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
        max_seq_len_query=64,
        max_seq_len_passage=256,
        batch_size=512,
    )

    corpus = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
    corpus = corpus.filter(lambda example: len(example["text"]) > min_chars)

    def embed_passages_retrieval(examples):
        # Bind the retriever so the callback matches the one-argument
        # signature that Dataset.map expects for batched mapping.
        return embed_passages_haystack(retriever, examples)

    corpus = corpus.map(embed_passages_retrieval, batched=True, batch_size=8192)

    corpus.add_faiss_index(
        column="embeddings",
        # Faiss index-factory description; see the Faiss index factory docs
        # for the meaning of each comma-separated component.
        string_factory="OPQ64_128,IVF4898,PQ64x4fsr",
        # Train the index on every embedded row rather than a sample.
        train_size=len(corpus),
    )
    corpus.save_faiss_index("embeddings", args.index_file_name)
if __name__ == "__main__":
    # Script entry point: declare the CLI options, parse them, build the index.
    # Of these options, create_faiss_index only reads --index_file_name.
    # Each entry is (flag, default, help text); registration order is kept so
    # the generated --help output is unchanged.
    option_table = (
        (
            "--ctx_encoder_name",
            "IIC/dpr-spanish-passage_encoder-squades-base",
            "Encoding model to use for passage encoding",
        ),
        (
            "--index_file_name",
            "dpr_index_bio_splitted.faiss",
            "Faiss index file with passage embeddings",
        ),
        (
            "--device",
            "cuda:0",
            "The device to index data on.",
        ),
    )

    parser = argparse.ArgumentParser(description="Creates Faiss Wikipedia index file")
    for flag, default_value, help_text in option_table:
        parser.add_argument(flag, default=default_value, help=help_text)

    # parse_known_args tolerates unrecognised arguments; the leftovers are
    # discarded on purpose.
    main_args, _ = parser.parse_known_args()
    create_faiss_index(main_args)