davanstrien HF staff commited on
Commit
ec8173c
·
1 Parent(s): 7e16395

switch to molmo

Browse files
Files changed (1) hide show
  1. app.py +72 -63
app.py CHANGED
@@ -1,15 +1,16 @@
1
- import subprocess # 🥲
 
 
 
 
 
 
2
 
3
- subprocess.run(
4
- "pip install flash-attn --no-build-isolation",
5
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
6
- shell=True,
7
- )
8
  import spaces
9
  import gradio as gr
10
-
11
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
12
- from qwen_vl_utils import process_vision_info
13
  import torch
14
  import os
15
  import json
@@ -18,15 +19,28 @@ from typing import Tuple
18
 
19
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
20
 
21
-
22
- model = Qwen2VLForConditionalGeneration.from_pretrained(
23
- "Qwen/Qwen2-VL-7B-Instruct",
24
- torch_dtype=torch.bfloat16,
25
- attn_implementation="flash_attention_2",
26
- device_map="auto",
 
 
 
 
 
 
27
  )
28
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
29
 
 
 
 
 
 
 
 
 
30
 
31
  class GeneralRetrievalQuery(BaseModel):
32
  broad_topical_query: str
@@ -36,7 +50,6 @@ class GeneralRetrievalQuery(BaseModel):
36
  visual_element_query: str
37
  visual_element_explanation: str
38
 
39
-
40
  def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
41
  if prompt_name != "general":
42
  raise ValueError("Only 'general' prompt is available in this version")
@@ -76,78 +89,74 @@ Generate the queries based on this image and provide the response in the specifi
76
 
77
  return prompt, GeneralRetrievalQuery
78
 
79
-
80
- # defined like this so we can later add more prompting options
81
  prompt, pydantic_model = get_retrieval_prompt("general")
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  def _prep_data_for_input(image):
85
- messages = [
86
- {
87
- "role": "user",
88
- "content": [
89
- {
90
- "type": "image",
91
- "image": image,
92
- },
93
- {"type": "text", "text": prompt},
94
- ],
95
- }
96
- ]
97
-
98
- text = processor.apply_chat_template(
99
- messages, tokenize=False, add_generation_prompt=True
100
  )
101
 
102
- image_inputs, video_inputs = process_vision_info(messages)
103
-
104
- return processor(
105
- text=[text],
106
- images=image_inputs,
107
- videos=video_inputs,
108
- padding=True,
109
- return_tensors="pt",
110
- )
111
-
112
-
113
  @spaces.GPU
114
  def generate_response(image):
115
  inputs = _prep_data_for_input(image)
116
- inputs = inputs.to("cuda")
117
-
118
- generated_ids = model.generate(**inputs, max_new_tokens=200)
119
- generated_ids_trimmed = [
120
- out_ids[len(in_ids) :]
121
- for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
122
- ]
123
-
124
- output_text = processor.batch_decode(
125
- generated_ids_trimmed,
126
- skip_special_tokens=True,
127
- clean_up_tokenization_spaces=False,
128
  )
 
 
 
129
  try:
130
- return json.loads(output_text[0])
131
  except Exception:
132
  gr.Warning("Failed to parse JSON from output")
133
  return {}
134
 
135
-
136
  title = "ColPali fine-tuning Query Generator"
137
  description = """[ColPali](https://huggingface.co/papers/2407.01449) is a very exciting new approach to multimodal document retrieval which aims to replace existing document retrievers which often rely on an OCR step with an end-to-end multimodal approach.
138
 
139
  To train or fine-tune a ColPali model, we need a dataset of image-text pairs which represent the document images and the relevant text queries which those documents should match.
140
  To make the ColPali models work even better we might want a dataset of query/image document pairs related to our domain or task.
141
 
142
- One way in which we might go about generating such a dataset is to use an VLM to generate synthetic queries for us.
143
- This space uses the [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) VLM model to generate queries for a document, based on an input document image.
144
 
145
  **Note** there is a lot of scope for improving to prompts and the quality of the generated queries! If you have any suggestions for improvements please [open a Discussion](https://huggingface.co/spaces/davanstrien/ColPali-Query-Generator/discussions/new)!
146
 
147
  This [blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) gives an overview of how you can use this kind of approach to generate a full dataset for fine-tuning ColPali models.
148
 
149
  If you want to convert a PDF(s) to a dataset of page images you can try out the [ PDFs to Page Images Converter](https://huggingface.co/spaces/Dataset-Creation-Tools/pdf-to-page-images-dataset) Space.
150
-
151
  """
152
 
153
  examples = [
@@ -163,4 +172,4 @@ demo = gr.Interface(
163
  description=description,
164
  examples=examples,
165
  )
166
- demo.launch()
 
1
+ # import subprocess # 🥲 need for flash attention in QWEN model
2
+
3
+ # subprocess.run(
4
+ # "pip install flash-attn --no-build-isolation",
5
+ # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
6
+ # shell=True,
7
+ # )
8
 
 
 
 
 
 
9
  import spaces
10
  import gradio as gr
11
+ from transformers import AutoModelForCausalLM, AutoProcessor
12
+ # from transformers import Qwen2VLForConditionalGeneration # Uncomment when adding QWEN back
13
+ # from qwen_vl_utils import process_vision_info # Uncomment when adding QWEN back
14
  import torch
15
  import os
16
  import json
 
19
 
20
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
21
 
22
+ # Load Molmo model
23
+ model = AutoModelForCausalLM.from_pretrained(
24
+ 'allenai/Molmo-7B-D-0924',
25
+ trust_remote_code=True,
26
+ torch_dtype='auto',
27
+ device_map='auto'
28
+ )
29
+ processor = AutoProcessor.from_pretrained(
30
+ 'allenai/Molmo-7B-D-0924',
31
+ trust_remote_code=True,
32
+ torch_dtype='auto',
33
+ device_map='auto'
34
  )
 
35
 
36
+ # # Load Qwen model (commented out for now)
37
+ # qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
38
+ # "Qwen/Qwen2-VL-7B-Instruct",
39
+ # torch_dtype=torch.bfloat16,
40
+ # attn_implementation="flash_attention_2",
41
+ # device_map="auto",
42
+ # )
43
+ # qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
44
 
45
  class GeneralRetrievalQuery(BaseModel):
46
  broad_topical_query: str
 
50
  visual_element_query: str
51
  visual_element_explanation: str
52
 
 
53
  def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
54
  if prompt_name != "general":
55
  raise ValueError("Only 'general' prompt is available in this version")
 
89
 
90
  return prompt, GeneralRetrievalQuery
91
 
 
 
92
  prompt, pydantic_model = get_retrieval_prompt("general")
93
 
94
+ # def _prep_data_for_input_qwen(image):
95
+ # messages = [
96
+ # {
97
+ # "role": "user",
98
+ # "content": [
99
+ # {
100
+ # "type": "image",
101
+ # "image": image,
102
+ # },
103
+ # {"type": "text", "text": prompt},
104
+ # ],
105
+ # }
106
+ # ]
107
+ #
108
+ # text = qwen_processor.apply_chat_template(
109
+ # messages, tokenize=False, add_generation_prompt=True
110
+ # )
111
+ #
112
+ # image_inputs, video_inputs = process_vision_info(messages)
113
+ #
114
+ # return qwen_processor(
115
+ # text=[text],
116
+ # images=image_inputs,
117
+ # videos=video_inputs,
118
+ # padding=True,
119
+ # return_tensors="pt",
120
+ # )
121
 
122
  def _prep_data_for_input(image):
123
+ return processor.process(
124
+ images=[image],
125
+ text=prompt
 
 
 
 
 
 
 
 
 
 
 
 
126
  )
127
 
 
 
 
 
 
 
 
 
 
 
 
128
  @spaces.GPU
129
  def generate_response(image):
130
  inputs = _prep_data_for_input(image)
131
+ inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
132
+ output = model.generate_from_batch(
133
+ inputs,
134
+ gr.GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
135
+ tokenizer=processor.tokenizer
 
 
 
 
 
 
 
136
  )
137
+ generated_tokens = output[0, inputs['input_ids'].size(1):]
138
+ output_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
139
+
140
  try:
141
+ return json.loads(output_text)
142
  except Exception:
143
  gr.Warning("Failed to parse JSON from output")
144
  return {}
145
 
 
146
  title = "ColPali fine-tuning Query Generator"
147
  description = """[ColPali](https://huggingface.co/papers/2407.01449) is a very exciting new approach to multimodal document retrieval which aims to replace existing document retrievers which often rely on an OCR step with an end-to-end multimodal approach.
148
 
149
  To train or fine-tune a ColPali model, we need a dataset of image-text pairs which represent the document images and the relevant text queries which those documents should match.
150
  To make the ColPali models work even better we might want a dataset of query/image document pairs related to our domain or task.
151
 
152
+ One way in which we might go about generating such a dataset is to use a VLM to generate synthetic queries for us.
153
+ This space uses the [allenai/Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) model to generate queries for a document, based on an input document image.
154
 
155
  **Note** there is a lot of scope for improving to prompts and the quality of the generated queries! If you have any suggestions for improvements please [open a Discussion](https://huggingface.co/spaces/davanstrien/ColPali-Query-Generator/discussions/new)!
156
 
157
  This [blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) gives an overview of how you can use this kind of approach to generate a full dataset for fine-tuning ColPali models.
158
 
159
  If you want to convert a PDF(s) to a dataset of page images you can try out the [ PDFs to Page Images Converter](https://huggingface.co/spaces/Dataset-Creation-Tools/pdf-to-page-images-dataset) Space.
 
160
  """
161
 
162
  examples = [
 
172
  description=description,
173
  examples=examples,
174
  )
175
+ demo.launch()