metadata
language:
  - en
license: mit
library_name: Tevatron
tags:
  - vidore
datasets:
  - Tevatron/docmatix-ir
  - HuggingFaceM4/Docmatix
  - Tevatron/msmarco-passage-aug
  - vidore/colpali_train_set
  - Tevatron/wiki-ss-nq

DSE-Phi35-Vidore-ft

DSE-Phi35-Vidore-ft is a bi-encoder model designed to encode document screenshots into dense vectors for document retrieval. The Document Screenshot Embedding (DSE) approach captures documents in their original visual format, preserving all information such as text, images, and layout, and so avoids tedious parsing and potential information loss.

The model, Tevatron/dse-phi35-vidore-ft, is trained on 1/10 of the Tevatron/docmatix-ir dataset, a variant of HuggingFaceM4/Docmatix adapted for training PDF retrievers with Vision Language Models in open-domain question answering scenarios. For more information on dataset filtering and hard negative mining, refer to the docmatix-ir dataset page. The model is then fine-tuned on the [vidore](https://huggingface.co/datasets/vidore/colpali_train_set) training set. The checkpoint is warmed up on text retrieval and webpage retrieval.

DSE-Phi35-Vidore-ft achieves 82.9 nDCG@5 on the ViDoRe leaderboard.
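
For readers unfamiliar with the metric, the following is a minimal sketch of how nDCG@5 is commonly computed for a single query. It uses the standard linear-gain formulation and assumes every relevant document appears somewhere in the ranked list; it is not necessarily the exact implementation behind the leaderboard. The function names and the toy relevance labels are illustrative only.

```python
import math

def dcg_at_k(relevances, k):
    """Discounted cumulative gain over the first k relevance labels of a ranking."""
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances[:k]))

def ndcg_at_k(ranked_relevances, k=5):
    """nDCG@k: DCG of the given ranking divided by the DCG of the ideal ordering."""
    ideal_dcg = dcg_at_k(sorted(ranked_relevances, reverse=True), k)
    if ideal_dcg == 0:
        return 0.0  # no relevant document in the list
    return dcg_at_k(ranked_relevances, k) / ideal_dcg

# Hypothetical rankings: 1 marks the relevant page, 0 an irrelevant one.
print(ndcg_at_k([1, 0, 0, 0, 0]))  # relevant page ranked first
print(ndcg_at_k([0, 0, 1, 0, 0]))  # relevant page ranked third
```

A per-query score lies between 0 and 1. Benchmark numbers are typically the mean over all queries, often reported on a 0 to 100 scale.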

How to Train the Model from Scratch

Please see https://github.com/texttron/tevatron/tree/main/examples/dse

How to Use the Model

Load the Model and Processor

import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Requires a CUDA GPU and the flash-attn package (for attn_implementation="flash_attention_2").
processor = AutoProcessor.from_pretrained('MrLight/dse-phi35-vidore-ft', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('MrLight/dse-phi35-vidore-ft', trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, use_cache=False).to('cuda:0')

def get_embedding(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    sequence_lengths = attention_mask.sum(dim=1) - 1
    bs = last_hidden_state.shape[0]
    reps = last_hidden_state[torch.arange(bs, device=last_hidden_state.device), sequence_lengths]
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps
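
The `get_embedding` helper above performs last-token pooling: it takes the hidden state of the final non-padded token in each sequence and L2-normalizes it. The framework-free sketch below mirrors that logic on toy data, to make the indexing explicit. The function name and the toy values are hypothetical. Like the helper, it assumes padding is on the right, since `sum(mask) - 1` only points at the last real token in that case.

```python
import math

def last_token_pool(hidden_states, attention_mask):
    """Select the last non-padded token vector per sequence and L2-normalize it.

    hidden_states: list of sequences, each a list of per-token vectors.
    attention_mask: list of 0/1 lists (1 = real token); assumes right padding.
    """
    pooled = []
    for token_vectors, mask in zip(hidden_states, attention_mask):
        last_index = sum(mask) - 1  # index of the last real token
        vector = token_vectors[last_index]
        norm = math.sqrt(sum(x * x for x in vector)) or 1.0  # guard against a zero vector
        pooled.append([x / norm for x in vector])
    return pooled

# Toy batch: two sequences of three 2-dimensional token vectors.
hidden = [
    [[1.0, 0.0], [3.0, 4.0], [9.0, 9.0]],  # third position is padding
    [[0.0, 2.0], [1.0, 1.0], [0.0, 5.0]],  # no padding
]
mask = [[1, 1, 0], [1, 1, 1]]
embeddings = last_token_pool(hidden, mask)
print(embeddings)  # the padded position [9.0, 9.0] is never selected
```

Because each embedding has unit length, the dot product between a query embedding and a document embedding equals their cosine similarity.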

Encode Text Query

queries = ["query: Where can we see Llama?</s>", "query: What is LLaMA model?</s>"]
query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=128, truncation=True).to('cuda:0')
with torch.no_grad():
    output = model(**query_inputs, return_dict=True, output_hidden_states=True)
query_embeddings = get_embedding(output.hidden_states[-1], query_inputs["attention_mask"])

Encode Document Screenshot

from PIL import Image
import requests
from io import BytesIO

# URLs of the images
url1 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/animal-llama.png"
url2 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/meta-llama.png"

# Download and open images
response1 = requests.get(url1)
response2 = requests.get(url2)

passage_image1 = Image.open(BytesIO(response1.content)).resize((1344, 1344))
passage_image2 = Image.open(BytesIO(response2.content)).resize((1344, 1344))

passage_images = [passage_image1, passage_image2]
passage_prompts = ["<|image_1|>\nWhat is shown in this image?</s>", "<|image_2|>\nWhat is shown in this image?</s>"]

# Process inputs and get embeddings
passage_inputs = processor(passage_prompts, images=passage_images, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
passage_inputs['input_ids'] = passage_inputs['input_ids'].squeeze(0)
passage_inputs['attention_mask'] = passage_inputs['attention_mask'].squeeze(0)
passage_inputs['image_sizes'] = passage_inputs['image_sizes'].squeeze(0)
with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])

Compute Similarity

from torch.nn.functional import cosine_similarity
num_queries = query_embeddings.size(0)
num_passages = doc_embeddings.size(0)

for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")
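
In a retrieval setting, the per-query scores are usually turned into a ranking over documents. The sketch below shows this on toy unit-length vectors in plain Python. The helper names and the example vectors are hypothetical and stand in for the real `query_embeddings` and `doc_embeddings`.

```python
def similarity_matrix(query_vecs, doc_vecs):
    """Dot product of every query with every document (cosine similarity for unit vectors)."""
    return [[sum(q * d for q, d in zip(qv, dv)) for dv in doc_vecs] for qv in query_vecs]

def rank_documents(scores_row):
    """Return document indices sorted from most to least similar."""
    return sorted(range(len(scores_row)), key=lambda j: scores_row[j], reverse=True)

# Toy unit-length embeddings standing in for real model outputs.
toy_queries = [[1.0, 0.0], [0.0, 1.0]]
toy_docs = [[0.6, 0.8], [1.0, 0.0]]

scores = similarity_matrix(toy_queries, toy_docs)
for i, row in enumerate(scores):
    print(f"Query {i + 1}: scores={row}, ranking={rank_documents(row)}")
```

With the normalized tensors from the sections above, the same score matrix can be obtained in one step as `query_embeddings @ doc_embeddings.T`.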

Encode Document Text

This DSE checkpoint is warmed up with Tevatron/msmarco-passage-aug, so the model can also effectively encode documents given as text input.

passage_prompts = [
  "The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.</s>",
  "Llama (acronym for Large Language Model Meta AI, and formerly stylized as LLaMA) is a family of autoregressive large language models (LLMs) released by Meta AI starting in February 2023.[2][3] The latest version is Llama 3.1, released in July 2024.[4]</s>"
]

passage_inputs = processor(passage_prompts, images=None, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])

for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")

Citation

If you find this checkpoint helpful, please consider citing Phi3, Docmatix, ViDoRe, and our DSE work.