Integrate Sentence Transformers, prevent manual tokenizer EOS

#1
by tomaarsen (HF staff) - opened

Hello!

Pull Request overview

  • Integrate with Sentence Transformers (+ README updated, added Sentence Transformers tag to make this model easier to find)
  • Update the tokenizer.json TemplateProcessing so the EOS is always appended.
    • Simplify modeling_drama.py _tokenize as the EOS is now handled automatically.
  • Rename self.forward to self.encode in modeling_drama.py: this allows ST to work, as it applies its own pooling.
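To illustrate the tokenizer change: a TemplateProcessing post-processor can append the EOS token for every encoded sequence, so callers no longer need to add it manually. Below is a minimal standalone sketch (a tiny hypothetical vocabulary, not the actual drama tokenizer) showing the mechanism:

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.processors import TemplateProcessing

# Hypothetical three-token vocabulary, purely for illustration
vocab = {"<|begin_of_text|>": 0, "<|end_of_text|>": 1, "hello": 2}
tok = Tokenizer(WordLevel(vocab, unk_token="<|end_of_text|>"))

# The post-processor wraps every single sequence with BOS ... EOS
tok.post_processor = TemplateProcessing(
    single="<|begin_of_text|> $A <|end_of_text|>",
    special_tokens=[("<|begin_of_text|>", 0), ("<|end_of_text|>", 1)],
)

enc = tok.encode("hello")
print(enc.ids)  # [0, 2, 1] -> BOS, "hello", EOS appended automatically
```

The same idea, applied to the model's actual tokenizer.json, is what makes the manual EOS handling in _tokenize unnecessary.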

Details

This is a companion PR to https://huggingface.co/facebook/drama-base/discussions/1, which explains the changes made here in much more detail. To avoid copying the same text in multiple places, I'd recommend checking out that PR.

Note: There are some differences from that PR:

  • self.max_seq_len = 128000
    This model has the maximum sequence length set to 128k. I've kept this as-is, and reused 128k in sentence_bert_config.json.
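For reference, Sentence Transformers reads the maximum sequence length from sentence_bert_config.json; the file added in this PR would look roughly like this (assuming the library's standard keys):

```json
{
  "max_seq_length": 128000,
  "do_lower_case": false
}
```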

Additionally, consider running this to experiment:

import torch
from transformers import AutoTokenizer, AutoModel


queries = [
    'What percentage of the Earth\'s atmosphere is oxygen?',
    '意大利首都是哪里?',
]
documents = [
    "The amount of oxygen in the atmosphere has fluctuated over the last 600 million years, reaching a peak of 35% during the Carboniferous period, significantly higher than today's 21%.",
    "羅馬是欧洲国家意大利首都和罗马首都广域市的首府及意大利全国的政治、经济、文化和交通中心,位于意大利半島中部的台伯河下游平原地,建城初期在七座小山丘上,故又名“七丘之城”。按城市范围内的人口计算,罗马是意大利人口最多的城市,也是欧盟人口第三多的城市。",
]

model_name = "facebook/drama-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name, revision="refs/pr/1")
model = AutoModel.from_pretrained(model_name, revision="refs/pr/1", trust_remote_code=True).to(device)

# Embed queries and documents with the helpers from modeling_drama.py
query_embs = model.encode_queries(tokenizer, queries)
doc_embs = model.encode_documents(tokenizer, documents)

# Score every query against every document via dot product
scores = query_embs @ doc_embs.T
print(scores.tolist())
# Expected output: [[0.5062, 0.1475], [0.1837, 0.6331]]

# An extra test: verify that the EOS token is now appended automatically
tokenized = tokenizer("This is my text")
decoded = tokenizer.decode(tokenized["input_ids"])
print(decoded)
# <|begin_of_text|>This is my text<|end_of_text|>
tomaarsen changed pull request status to open
