Jack Morris
commited on
Commit
•
546fe43
1
Parent(s):
6b3423f
add model use information
Browse files
README.md
CHANGED
@@ -8645,4 +8645,96 @@ model-index:
|
|
8645 |
---
|
8646 |
# Contextual Document Embeddings (CDE)
|
8647 |
|
8648 |
-
Our new model that naturally integrates "context tokens" into the embedding process.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8645 |
---
|
8646 |
# Contextual Document Embeddings (CDE)
|
8647 |
|
8648 |
+
Our new model that naturally integrates "context tokens" into the embedding process.
|
8649 |
+
|
8650 |
+
# Using `cde-small-v1`
|
8651 |
+
|
8652 |
+
Our embedding model needs to be used in *two stages*. The first stage is to gather some dataset information by embedding a subset of the corpus using our "first-stage" model. The second stage is to actually embed queries and documents, conditioning on the corpus information from the first stage. Note that we can do the first stage part offline and only use the second-stage weights at inference time.
|
8653 |
+
|
8654 |
+
|
8655 |
+
#### Note on prefixes
|
8656 |
+
|
8657 |
+
*Nota bene*: Like all state-of-the-art embedding models, our model was trained with task-specific prefixes. To do retrieval, you can prepend the following strings to queries & documents:
|
8658 |
+
|
8659 |
+
```python
|
8660 |
+
query_prefix = "search_query: "
|
8661 |
+
document_prefix = "search_document: "
|
8662 |
+
```
|
8663 |
+
|
8664 |
+
## First stage
|
8665 |
+
|
8666 |
+
```python
|
8667 |
+
minicorpus_size = model.config.transductive_corpus_size
|
8668 |
+
minicorpus_docs = [ ... ] # Put some strings here that are representative of your corpus, for example by calling random.sample(corpus, k=minicorpus_size)
|
8669 |
+
assert len(minicorpus_docs) == minicorpus_size # You must use exactly this many documents in the minicorpus. You can oversample if your corpus is smaller.
|
8670 |
+
minicorpus_docs = tokenizer(
|
8671 |
+
[document_prefix + doc for doc in minicorpus_docs],
|
8672 |
+
truncation=True,
|
8673 |
+
padding=True,
|
8674 |
+
max_length=512,
|
8675 |
+
return_tensors="pt"
|
8676 |
+
)
|
8677 |
+
import torch
|
8678 |
+
from tqdm.autonotebook import tqdm
|
8679 |
+
|
8680 |
+
batch_size = 32
|
8681 |
+
|
8682 |
+
dataset_embeddings = []
|
8683 |
+
for i in tqdm(range(0, len(minicorpus_docs["input_ids"]), batch_size)):
|
8684 |
+
minicorpus_docs_batch = {k: v[i:i+batch_size] for k,v in minicorpus_docs.items()}
|
8685 |
+
with torch.no_grad():
|
8686 |
+
dataset_embeddings.append(
|
8687 |
+
model.first_stage_model(**minicorpus_docs_batch)
|
8688 |
+
)
|
8689 |
+
|
8690 |
+
dataset_embeddings = torch.cat(dataset_embeddings)
|
8691 |
+
|
8692 |
+
## Running the second stage
|
8693 |
+
|
8694 |
+
Now that we have obtained "dataset embeddings" we can embed documents and queries like normal. Remember to use the document prefix for documents:
|
8695 |
+
```python
|
8696 |
+
docs = tokenizer(
|
8697 |
+
[document_prefix + doc for doc in docs],
|
8698 |
+
truncation=True,
|
8699 |
+
padding=True,
|
8700 |
+
max_length=512,
|
8701 |
+
return_tensors="pt"
|
8702 |
+
).to(device)
|
8703 |
+
|
8704 |
+
with torch.no_grad():
|
8705 |
+
doc_embeddings = model.second_stage_model(
|
8706 |
+
input_ids=docs["input_ids"],
|
8707 |
+
attention_mask=docs["attention_mask"],
|
8708 |
+
dataset_embeddings=dataset_embeddings,
|
8709 |
+
)
|
8710 |
+
doc_embeddings /= doc_embeddings.norm(p=2, dim=1, keepdim=True)
|
8711 |
+
```
|
8712 |
+
|
8713 |
+
and the query prefix for queries:
|
8714 |
+
```python
|
8715 |
+
queries = queries.select(range(16))["text"]
|
8716 |
+
queries = tokenizer(
|
8717 |
+
[query_prefix + query for query in queries],
|
8718 |
+
truncation=True,
|
8719 |
+
padding=True,
|
8720 |
+
max_length=512,
|
8721 |
+
return_tensors="pt"
|
8722 |
+
).to(device)
|
8723 |
+
|
8724 |
+
with torch.no_grad():
|
8725 |
+
query_embeddings = model.second_stage_model(
|
8726 |
+
input_ids=queries["input_ids"],
|
8727 |
+
attention_mask=queries["attention_mask"],
|
8728 |
+
dataset_embeddings=dataset_embeddings,
|
8729 |
+
)
|
8730 |
+
query_embeddings /= query_embeddings.norm(p=2, dim=1, keepdim=True)
|
8731 |
+
```
|
8732 |
+
|
8733 |
+
these embeddings can be compared using dot product, since they're normalized.
|
8734 |
+
|
8735 |
+
### What if I don't know what my corpus will be ahead of time?
|
8736 |
+
|
8737 |
+
### Colab demo
|
8738 |
+
|
8739 |
+
We've set up a short demo in a Colab notebook showing how you might use our model:
|
8740 |
+
[Try our model in Colab:](https://colab.research.google.com/drive/1r8xwbp7_ySL9lP-ve4XMJAHjidB9UkbL?usp=sharing)
|