metadata
language:
  - en
tags:
  - retrieval
  - entity-retrieval
  - named-entity-disambiguation
  - entity-disambiguation
  - named-entity-linking
  - entity-linking
  - text2text-generation

GENRE

The GENRE (Generative ENtity REtrieval) system, as presented in Autoregressive Entity Retrieval, implemented in PyTorch.

In a nutshell, GENRE uses a sequence-to-sequence approach to entity retrieval (e.g., linking), based on a fine-tuned BART architecture. GENRE performs retrieval by generating the unique entity name conditioned on the input text, using constrained beam search so that only valid identifiers are generated. The model was first released in the facebookresearch/GENRE repository using fairseq (the transformers models are obtained with a conversion script similar to this).
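To make the idea of constrained generation concrete, here is a minimal sketch of a prefix tree over token-id sequences. It is an illustration only: `ToyTrie` and the token ids are hypothetical and are not the `trie.py` shipped with this model. The point is that, given the tokens generated so far, the tree returns the only tokens that can still lead to a valid entity name, which is the kind of lookup a `prefix_allowed_tokens_fn` callback performs during beam search.

```python
from typing import Dict, List


class ToyTrie:
    """Hypothetical minimal prefix tree over token-id sequences (not the repository's trie.py)."""

    def __init__(self, sequences: List[List[int]]) -> None:
        # Nested dicts: each key is a token id, each value is the subtree that follows it.
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            node = self.root
            for tok in seq:
                node = node.setdefault(tok, {})

    def get(self, prefix: List[int]) -> List[int]:
        """Return the token ids that may follow `prefix`; empty list if the prefix is not in the tree."""
        node = self.root
        for tok in prefix:
            if tok not in node:
                return []
            node = node[tok]
        return sorted(node.keys())


# Made-up token ids standing in for three tokenized entity names that share a start token (2).
toy_trie = ToyTrie([[2, 10, 11, 1], [2, 10, 12, 1], [2, 20, 1]])

allowed_after_start = toy_trie.get([2])        # tokens that can follow the start token
allowed_after_prefix = toy_trie.get([2, 10])   # tokens that can follow the prefix [2, 10]
allowed_after_invalid = toy_trie.get([2, 99])  # prefix not in the tree, so nothing is allowed
```

At each decoding step the beam search would be restricted to the returned ids, so every finished hypothesis spells out one of the stored sequences.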

This model was trained on the full training set of KILT (i.e., 11 datasets for fact-checking, entity-linking, slot filling, dialogue, open-domain extractive and abstractive QA).

BibTeX entry and citation info

Please consider citing our work if you use code from this repository.

@inproceedings{decao2020autoregressive,
  title={Autoregressive Entity Retrieval},
  author={Nicola {De Cao} and Gautier Izacard and Sebastian Riedel and Fabio Petroni},
  booktitle={International Conference on Learning Representations},
  url={https://openreview.net/forum?id=5k8F6UU39V},
  year={2021}
}

Usage

Here is an example of generation for Wikipedia page retrieval for open-domain fact-checking:

import pickle
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# OPTIONAL: load the prefix tree (trie), you need to additionally download
# https://huggingface.co/facebook/genre-kilt/blob/main/trie.py and 
# https://huggingface.co/facebook/genre-kilt/blob/main/kilt_titles_trie_dict.pkl
# from trie import Trie
# with open("kilt_titles_trie_dict.pkl", "rb") as f:
#     trie = Trie.load_from_dict(pickle.load(f))

tokenizer = AutoTokenizer.from_pretrained("facebook/genre-kilt")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/genre-kilt").eval()

sentences = ["Einstein was a German physicist."]

outputs = model.generate(
    **tokenizer(sentences, return_tensors="pt"),
    num_beams=5,
    num_return_sequences=5,
    # OPTIONAL: use constrained beam search
    # prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)

tokenizer.batch_decode(outputs, skip_special_tokens=True)

which outputs the following top-5 predictions (using constrained beam search):

['Albert Einstein',
 'Erwin Schrödinger',
 'Werner Bruschke',
 'Werner von Habsburg',
 'Werner von Moltke']
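When `sentences` contains more than one input, `batch_decode` returns a single flat list with `num_return_sequences` candidates per input, in input order. The helper below is a small sketch for regrouping that list per sentence; `group_predictions` is a hypothetical name introduced here, and the second group of strings is a placeholder, not real model output.

```python
from typing import List


def group_predictions(decoded: List[str], num_return_sequences: int) -> List[List[str]]:
    """Split a flat list of decoded candidates into one list per input sentence."""
    if num_return_sequences <= 0:
        raise ValueError("num_return_sequences must be positive")
    if len(decoded) % num_return_sequences != 0:
        raise ValueError("decoded length is not a multiple of num_return_sequences")
    return [
        decoded[i : i + num_return_sequences]
        for i in range(0, len(decoded), num_return_sequences)
    ]


# Flat list as it would come back for two inputs with two candidates each.
# The "placeholder" entries are illustrative only.
flat = ["Albert Einstein", "Erwin Schrödinger", "placeholder B1", "placeholder B2"]
grouped = group_predictions(flat, num_return_sequences=2)
```

With the example above, `grouped[0]` holds the candidates for the first sentence and `grouped[1]` those for the second. Batching several sentences will typically also require passing `padding=True` to the tokenizer.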