Integrate Sentence Transformers, prevent manual tokenizer EOS
#1 by tomaarsen · opened
Hello!
**Pull Request overview**

- Integrate with Sentence Transformers (+ README updated, added a Sentence Transformers tag to make this model easier to find)
- Update the `tokenizer.json` `TemplateProcessing` so the EOS is always appended.
- Simplify `modeling_drama.py` `_tokenize`, as the EOS is now handled automatically.
- Rename `self.forward` to `self.encode` in `modeling_drama.py`: this allows Sentence Transformers to work, as it uses its own pooling.
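The EOS-appending change can be sketched with the `tokenizers` library's `TemplateProcessing` post-processor. Below is a minimal toy illustration only: the vocabulary and token IDs are made up, and the real template, token names, and IDs live in the model's `tokenizer.json`.

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing

# Toy vocabulary; the real token names and IDs come from tokenizer.json.
vocab = {"<|begin_of_text|>": 0, "<|end_of_text|>": 1, "hello": 2, "world": 3}
tok = Tokenizer(WordLevel(vocab, unk_token="hello"))
tok.pre_tokenizer = Whitespace()

# The post-processor wraps every encoded sequence with BOS and EOS,
# so callers no longer need to append the EOS token manually.
tok.post_processor = TemplateProcessing(
    single="<|begin_of_text|> $A <|end_of_text|>",
    special_tokens=[("<|begin_of_text|>", 0), ("<|end_of_text|>", 1)],
)

ids = tok.encode("hello world").ids
print(ids)  # [0, 2, 3, 1] — BOS and EOS added automatically
```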
Details
This is a companion PR to https://huggingface.co/facebook/drama-base/discussions/1, which explains the changes made here in much more detail. To avoid copying the same text in multiple places, I'd recommend checking out that PR.
Note: There are some differences with that PR: this model has the maximum sequence length set to 128k (`self.max_seq_len = 128000`). I've kept this as-is and reused 128k in `sentence_bert_config.json`.
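For reference, the reused value might appear in `sentence_bert_config.json` roughly like this (only the 128k value comes from this PR; the `do_lower_case` field is an assumption based on typical Sentence Transformers configs):

```json
{
  "max_seq_length": 128000,
  "do_lower_case": false
}
```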
Additionally, consider running this to experiment:

```python
import torch
from transformers import AutoTokenizer, AutoModel

queries = [
    "What percentage of the Earth's atmosphere is oxygen?",
    "意大利首都是哪里?",  # "What is the capital of Italy?"
]
documents = [
    "The amount of oxygen in the atmosphere has fluctuated over the last 600 million years, reaching a peak of 35% during the Carboniferous period, significantly higher than today's 21%.",
    # "Rome is the capital of Italy and the seat of the Metropolitan City of Rome Capital, ..."
    "羅馬是欧洲国家意大利首都和罗马首都广域市的首府及意大利全国的政治、经济、文化和交通中心,位于意大利半島中部的台伯河下游平原地,建城初期在七座小山丘上,故又名“七丘之城”。按城市范围内的人口计算,罗马是意大利人口最多的城市,也是欧盟人口第三多的城市。",
]

model_name = "facebook/drama-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name, revision="refs/pr/1")
model = AutoModel.from_pretrained(model_name, revision="refs/pr/1", trust_remote_code=True).to(device)

query_embs = model.encode_queries(tokenizer, queries)
doc_embs = model.encode_documents(tokenizer, documents)

scores = query_embs @ doc_embs.T
print(scores.tolist())
# Expected output: [[0.5062, 0.1475], [0.1837, 0.6331]]

# An extra test: the EOS is now appended automatically by the tokenizer
tokenized = tokenizer("This is my text")
decoded = tokenizer.decode(tokenized["input_ids"])
print(decoded)
# <|begin_of_text|>This is my text<|end_of_text|>
```
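As a side note, the `query_embs @ doc_embs.T` scoring above is plain cosine similarity if the model returns L2-normalized embeddings (an assumption here, not verified against the model). A toy demonstration with random vectors standing in for embeddings:

```python
import torch
import torch.nn.functional as F

# Random vectors standing in for model embeddings; once L2-normalized,
# the matrix product of queries and documents equals cosine similarity.
torch.manual_seed(0)
query_embs = F.normalize(torch.randn(2, 8), dim=-1)
doc_embs = F.normalize(torch.randn(3, 8), dim=-1)

scores = query_embs @ doc_embs.T  # shape (2, 3): one score per query-document pair
cosine = F.cosine_similarity(query_embs.unsqueeze(1), doc_embs.unsqueeze(0), dim=-1)

print(torch.allclose(scores, cosine, atol=1e-6))  # True
```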
- Tom Aarsen

tomaarsen changed pull request status to open