tsunghanwu commited on
Commit
100ea29
·
verified ·
1 Parent(s): f39eaf0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +106 -3
README.md CHANGED
@@ -1,3 +1,106 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ ## MIRAGE
6
+
7
+ **Model Type:** MIRAGE is an innovative open-source visual-RAG model capable of processing over 10,000 images as input. It integrates a retriever and a large multimodal model (LMM) for enhanced performance.
8
+
9
+ **Key Features:**
10
+ - **Compressor:** Reduces data size by compressing image tokens by 18x per image, enabling efficient handling of large datasets.
11
+ - **Query-Aware Retriever:** Dynamically filters out irrelevant images to focus processing power on content that enhances task performance.
12
+ - **Multi-Image LMM:** Features a tailored pretraining and instruction tuning dataset, designed to optimize model performance across a range of multimodal tasks.
13
+
14
+ **Performance:**
15
+ - MIRAGE establishes a new benchmark in open-source performance on the [Visual Haystacks (VHs) benchmark](https://huggingface.co/datasets/tsunghanwu/visual_haystacks).
16
+ - Delivers robust results across various single- and multi-image question answering tasks, such as RETVQA, MMBench, MMVet, VQAv2, and more.
17
+
18
+ **Usage:**
19
+ Please refer to the installation guide on our GitHub repository to get started with MIRAGE: [Installation Guide](https://github.com/visual-haystacks/mirage)
20
+
21
+ **Additional Resources:**
22
+ For detailed information and updates, visit our project page: [Visual Haystacks Project](https://visual-haystacks.github.io/)
23
+
24
+ **Support:**
25
+ For questions or comments about the model, please open an issue on our GitHub page: [GitHub Issues](https://github.com/visual-haystacks/mirage/issues)
26
+
27
+ **Intended Use:**
28
+ MIRAGE is primarily intended for research into large multimodal models (LMMs), long-context modeling, and retrieval-augmented generation (RAG).
29
+
30
+ ### Example Usage Code
31
+
32
+ ```python
33
+ from PIL import Image
34
+ import argparse
35
+ import torch
36
+ import os
37
+
38
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
39
+ from llava.conversation import conv_templates
40
+ from llava.model.builder import load_pretrained_model
41
+ from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
42
+ from llava.utils import disable_torch_init
43
+
44
+ @torch.inference_mode()
45
+ def run(model_path, image_paths, prompt, num_retrievals=1):
46
+ '''
47
+ Executes MIRAGE with specified inputs to generate descriptive text based on the provided images.
48
+
49
+ Args:
50
+ model_path (str): Path to the MIRAGE model, e.g., 'tsunghanwu/mirage-llama3.1-8.3B'
51
+ image_paths (list): List of paths to image files, e.g., images in 'assets/example'
52
+ prompt (str): Text prompt for image description, e.g., 'Here are a set of random images in my photo album.
53
+ If you can find a cat, tell me what's the cat doing and what's its color.'
54
+ num_retrievals (int): Maximum number of images to retrieve and pass to the LMM
55
+
56
+ Returns:
57
+ output_text (str): Descriptive text generated by the LMM
58
+ output_ret (list): List of images retrieved by the model
59
+ '''
60
+ # Load the model and prepare the environment
61
+ model_name = get_model_name_from_path(model_path)
62
+ disable_torch_init()
63
+ model_name = os.path.expanduser(model_name)
64
+ tokenizer, model, image_processor, _ = \
65
+ load_pretrained_model(model_path=model_path, model_base=None, model_name=model_name, device="cuda")
66
+ model.eval_mode = True
67
+
68
+ # Process the images
69
+ clip_images = []
70
+ for image_path in image_paths:
71
+ image = Image.open(image_path). convert("RGB")
72
+ image_tensor = process_images([image], image_processor, model.config)[0]
73
+ image_tensor = image_tensor.to(dtype=torch.float16)
74
+ clip_images.append(image_tensor)
75
+
76
+ # Prepare text input and interaction
77
+ qformer_text_input = tokenizer(prompt, return_tensors='pt')["input_ids"].to(model.device)
78
+ N = len(clip_images)
79
+ img_str = DEFAULT_IMAGE_TOKEN * N + "\n"
80
+ inp = img_str + prompt
81
+ conv.append_message(conv.roles[0], inp)
82
+ conv.append_message(conv.roles[1], None)
83
+ prompt = conv.get_prompt()
84
+
85
+ # Generate model output
86
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
87
+ tokenizer.pad_token_id = 128002
88
+ batch_clip_images = [torch.stack(clip_images).to(model.device)]
89
+
90
+ output_ret, output_ids = model.generate(
91
+ input_ids,
92
+ pad_token_id=tokenizer.pad_token_id,
93
+ clip_images=batch_clip_images,
94
+ qformer_text_input=qformer_text_input,
95
+ relevance=None,
96
+ num_retrieval=num_retrievals,
97
+ do_sample=False,
98
+ max_new_tokens=512,
99
+ use_cache=True)
100
+
101
+ # Process output
102
+ output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
103
+ if not isinstance(output_ret[0], list):
104
+ output_ret[0] = output_ret[0].tolist()
105
+ return output_text, output_ret[0]
106
+ ```