---
license: mit
---

## MIRAGE

**Model Type:** MIRAGE is an open-source visual retrieval-augmented generation (visual-RAG) model that can take more than 10,000 images as input. It combines a retriever with a large multimodal model (LMM), so the LMM only has to process the images that are relevant to a query.

**Key Features:**
- **Compressor:** Compresses each image's visual tokens by 18x, so large image collections can be handled efficiently.
- **Query-Aware Retriever:** Dynamically filters out images that are irrelevant to the query, focusing computation on the content that helps answer it.
- **Multi-Image LMM:** Trained on a tailored pretraining and instruction-tuning dataset designed to improve performance across a range of multimodal tasks.
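
To see why per-image compression matters when the input holds thousands of images, the arithmetic can be sketched in a few lines of Python. This sketch assumes that an uncompressed image costs 576 visual tokens (the 24 x 24 patch grid of a CLIP ViT-L/14 encoder at 336 px, as in LLaVA-1.5); that figure is an assumption and is not stated in this card. The function name `visual_token_budget` is hypothetical.

```python
# Back-of-the-envelope token budget for the 18x per-image compression.
# ASSUMPTION: an uncompressed image is 576 visual tokens (24 x 24 patches,
# CLIP ViT-L/14 at 336 px, as in LLaVA-1.5). Not taken from the MIRAGE card.
UNCOMPRESSED_TOKENS_PER_IMAGE = 576
COMPRESSION_RATIO = 18  # per-image compression factor stated in the card


def visual_token_budget(num_images: int, compressed: bool = True) -> int:
    """Return the total number of visual tokens needed for `num_images` images."""
    per_image = UNCOMPRESSED_TOKENS_PER_IMAGE
    if compressed:
        per_image //= COMPRESSION_RATIO  # 576 // 18 == 32 tokens per image
    return num_images * per_image


if __name__ == "__main__":
    for n in (1, 100, 10_000):
        # Columns: images, uncompressed tokens, compressed tokens
        print(n, visual_token_budget(n, compressed=False), visual_token_budget(n))
```

Under that assumption, 10,000 images shrink from 5,760,000 visual tokens to 320,000. That is still larger than a 128K-token context window such as Llama 3.1's, which suggests why filtering images with the retriever matters as much as compressing them.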

**Performance:**
- MIRAGE sets a new state of the art among open-source models on the [Visual Haystacks (VHs) benchmark](https://huggingface.co/datasets/tsunghanwu/visual_haystacks).
- Performs strongly on a range of single- and multi-image question-answering tasks, including RETVQA, MMBench, MM-Vet, and VQAv2.

**Usage:**
Please refer to the installation guide on our GitHub repository to get started with MIRAGE: [Installation Guide](https://github.com/visual-haystacks/mirage)

**Additional Resources:**
For detailed information and updates, visit our project page: [Visual Haystacks Project](https://visual-haystacks.github.io/)

**Support:**
For questions or comments about the model, please open an issue on our GitHub page: [GitHub Issues](https://github.com/visual-haystacks/mirage/issues)

**Intended Use:**
MIRAGE is primarily intended for research into large multimodal models (LMMs), long-context modeling, and retrieval-augmented generation (RAG).

### Example Usage Code

```python
from PIL import Image
import torch
import os

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.utils import disable_torch_init

@torch.inference_mode()
def run(model_path, image_paths, prompt, num_retrievals=1):
    '''
    Executes MIRAGE with specified inputs to generate descriptive text based on the provided images.
    
    Args:
        model_path (str): Path to the MIRAGE model, e.g., 'tsunghanwu/mirage-llama3.1-8.3B'
        image_paths (list): List of paths to image files, e.g., images in 'assets/example'
        prompt (str): Text prompt for image description, e.g., 'Here are a set of random images in my photo album. 
                      If you can find a cat, tell me what's the cat doing and what's its color.'
        num_retrievals (int): Maximum number of images to retrieve and pass to the LMM
    
    Returns:
        output_text (str): Descriptive text generated by the LMM
        output_ret (list): List of images retrieved by the model
    '''
    # Load the model and prepare the environment
    disable_torch_init()
    # Expand '~' in local checkpoint paths before deriving the model name
    model_path = os.path.expanduser(model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path=model_path, model_base=None, model_name=model_name, device="cuda")
    model.eval_mode = True

    # Process the images
    clip_images = []
    for image_path in image_paths:
        image = Image.open(image_path).convert("RGB")
        image_tensor = process_images([image], image_processor, model.config)[0]
        image_tensor = image_tensor.to(dtype=torch.float16)
        clip_images.append(image_tensor)

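    # Prepare the text inputs: the raw prompt is tokenized on its own as the
    # Q-Former text input, and is also wrapped in the chat template for the LMM
    qformer_text_input = tokenizer(prompt, return_tensors='pt')["input_ids"].to(model.device)
    N = len(clip_images)
    img_str = DEFAULT_IMAGE_TOKEN * N + "\n"  # one image placeholder token per image
    inp = img_str + prompt
    # NOTE: the template key "llama_3" is an assumption; check llava/conversation.py
    # in the MIRAGE repository for the template that matches this checkpoint
    conv = conv_templates["llama_3"].copy()
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()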

    # Generate model output
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    tokenizer.pad_token_id = 128002  # a reserved special-token ID in the Llama 3 vocabulary, used as padding
    batch_clip_images = [torch.stack(clip_images).to(model.device)]

    output_ret, output_ids = model.generate(
        input_ids,
        pad_token_id=tokenizer.pad_token_id,
        clip_images=batch_clip_images,
        qformer_text_input=qformer_text_input,
        relevance=None,
        num_retrieval=num_retrievals,
        do_sample=False,
        max_new_tokens=512,
        use_cache=True)

    # Process output
    output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    if not isinstance(output_ret[0], list):
        output_ret[0] = output_ret[0].tolist()
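    return output_text, output_ret[0]


if __name__ == "__main__":
    import glob

    # Example call using the model and image folder named in the docstring above
    image_paths = sorted(glob.glob("assets/example/*"))
    text, retrieved = run(
        "tsunghanwu/mirage-llama3.1-8.3B",
        image_paths,
        "Here are a set of random images in my photo album. "
        "If you can find a cat, tell me what's the cat doing and what's its color.",
        num_retrievals=1,
    )
    print(text)
    print(retrieved)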
```
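
The `num_retrievals` argument caps how many images are passed on to the LMM. The plain-Python sketch below illustrates the general idea of query-aware filtering with such a cap. It assumes the retriever has already produced one relevance score in [0, 1] per image, keeps only images scoring at or above a threshold, and returns at most `max_k` of them, highest score first. The function `select_images`, the 0.5 threshold, and the example scores are all hypothetical; this is not MIRAGE's actual retrieval code.

```python
from typing import List, Sequence


def select_images(scores: Sequence[float], max_k: int, threshold: float = 0.5) -> List[int]:
    """Return indices of at most `max_k` images whose relevance score is >= threshold.

    Hypothetical illustration of query-aware filtering; not MIRAGE's implementation.
    """
    if max_k < 1:
        return []
    # Keep only images the (assumed) retriever scores as relevant to the query
    candidates = [i for i, s in enumerate(scores) if s >= threshold]
    # Rank the survivors by score, highest first (stable sort keeps index order on ties)
    candidates.sort(key=lambda i: scores[i], reverse=True)
    # Apply the cap, analogous to num_retrievals in run()
    return candidates[:max_k]


if __name__ == "__main__":
    # Made-up scores for five images; only images 1 and 3 pass the 0.5 threshold
    example_scores = [0.10, 0.92, 0.40, 0.75, 0.05]
    print(select_images(example_scores, max_k=1))  # [1]
    print(select_images(example_scores, max_k=3))  # [1, 3]
```

Because filtering happens before the cap, the LMM can receive fewer than `max_k` images when few are relevant, which is consistent with the docstring describing `num_retrievals` as a maximum.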