quarterturn committed
Commit 837bad6 · 1 Parent(s): 2a58c0b

first commit

Files changed (7)
  1. README.md +26 -3
  2. caption.py +74 -0
  3. example.png +0 -0
  4. main.py +207 -0
  5. model/Molmo-7B-D-0924 +1 -0
  6. requirements.txt +11 -0
  7. test-images.zip +3 -0
README.md CHANGED
@@ -1,3 +1,26 @@
- ---
- license: cc-by-nc-4.0
- ---
+ ---
+ license: cc-by-nc-4.0
+ ---
+ Molmo 7B Flux Dev Image Captioner.
+ ![Screenshot](example.png)
+
+ A simple Python and Gradio script that uses Molmo 7B for image captioning. The prompt is currently written to produce captions that work well for Flux Dev LoRA training, but you can adjust it to suit other models' captioning styles.
+
+ Install:
+ 1. create a python3 venv or use conda to create an environment, e.g.:
+ ``` conda create -n caption python=3.11 ```
+ 2. activate your environment, e.g.:
+ ``` conda activate caption ```
+ 3. install the dependencies:
+ ``` pip3 install -r requirements.txt ```
+ 4. run the gradio version:
+ ``` python3 main.py ```
+ 1. create a zip file of images
+ 2. upload it
+ 3. process it
+ 4. click the button to download the caption zip file; the link is at the top of the page
+
+ run the command-line version:
+ ``` python3 caption.py ```
+ 1. make sure your images are in the "images" directory
+ 2. captions will be placed in the "images" directory
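Both scripts embed the captioning prompt inline in their processor.process() call; adapting the output to another trainer's captioning style means editing that string in caption.py and main.py. A minimal sketch of what that adjustment touches (the PROMPT constant below is only illustrative; the committed scripts keep the string inline):

```python
# illustrative sketch: a single place to tune the captioning style for other trainers;
# the scripts would then pass text=PROMPT to processor.process() instead of the inline string
PROMPT = (
    "Describe what you see in vivid detail, without line breaks. "
    # ... rest of the Flux Dev LoRA oriented prompt, or a style suited to another trainer ...
    "Limit your response to 276 words to avoid your description getting cut off."
)
```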
caption.py ADDED
@@ -0,0 +1,74 @@
+ import os
+ import torch
+ from PIL import Image
+ import requests
+ from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
+
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+     print("GPU is available. Using CUDA.")
+ else:
+     device = torch.device("cpu")
+     print("GPU is not available. Using CPU.")
+
+ # load the processor from local path
+ local_path = "./model/Molmo-7B-D-0924"
+ processor = AutoProcessor.from_pretrained(
+     local_path,
+     local_files_only=True,
+     trust_remote_code=True,
+     torch_dtype='auto',
+     device_map='auto'
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     local_path,
+     trust_remote_code=True,
+     torch_dtype='auto',
+     device_map='auto',
+ )
+
+
+ model.to(dtype=torch.bfloat16)
+
+ # directory containing the images
+ image_directory = "./images"
+
+ # iterate through the images in the directory
+ for filename in os.listdir(image_directory):
+     if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png"):  # add more image extensions if needed
+         image_path = os.path.join(image_directory, filename)
+         image = Image.open(image_path)
+
+         # process the image and text
+         inputs = processor.process(
+             images=[image],
+             text="Describe what you see in vivid detail, without line breaks. Include information about the pose of characters, their facial expression, their height, body type, weight, the position of their limbs, and the direction of their gaze, the color of their eyes, hair, and skin. If you know a person or place name, provide it. If you know the name of an artist who may have created what you see, provide that. Do not provide opinions or value judgements. Limit your response to 276 words to avoid your description getting cut off.",
+         )
+
+         # move inputs to the correct device and make a batch of size 1
+         inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
+         inputs["images"] = inputs["images"].to(torch.bfloat16)
+
+         # generate output; maximum 500 new tokens; stop generation when <|endoftext|> is generated
+         with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+             output = model.generate_from_batch(
+                 inputs,
+                 GenerationConfig(max_new_tokens=500, stop_strings="<|endoftext|>"),
+                 tokenizer=processor.tokenizer,
+             )
+
+         # only get generated tokens; decode them to text
+         generated_tokens = output[0, inputs["input_ids"].size(1):]
+         generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+         # print the generated text
+         print("Caption for: ", filename)
+         print(generated_text)
+         # print a divider
+         print("*---------------------------------------------------*")
+
+         # save the generated text to a file
+         output_filename = os.path.splitext(filename)[0] + ".txt"
+         with open(os.path.join(image_directory, output_filename), "w") as file:
+             file.write(generated_text)
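caption.py filters files by extension only, so a corrupt or mislabeled file would stop the whole run at Image.open. If that is a concern, a small guard lets the loop skip unreadable files instead; a minimal sketch (the helper name is illustrative, not part of the committed script):

```python
# minimal sketch: skip unreadable files instead of letting one bad image abort the run
from PIL import Image, UnidentifiedImageError

def load_image_or_none(image_path):
    """Return an RGB PIL image, or None if the file cannot be decoded."""
    try:
        return Image.open(image_path).convert("RGB")
    except (UnidentifiedImageError, OSError) as err:
        print(f"Skipping {image_path}: {err}")
        return None
```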
example.png ADDED
main.py ADDED
@@ -0,0 +1,207 @@
+ # note: if you have a mix of Ampere and newer, and also older than Ampere GPUs, set the environment variable
+ # CUDA_VISIBLE_DEVICES=1,2,3 (for example) so that one or the other is excluded.
+ # otherwise the script may fail with a flash attention exception.
+
+ import gradio as gr
+ import os
+ import uuid
+ import zipfile
+ import torch
+ from PIL import Image
+ import requests
+ from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
+ from io import BytesIO
+ import base64
+ import atexit
+ import shutil
+
+
+ def cleanup_temp_files():
+     # Delete the subdirectories inside the "images" directory
+     if os.path.exists("images"):
+         for dir_name in os.listdir("images"):
+             dir_path = os.path.join("images", dir_name)
+             if os.path.isdir(dir_path):
+                 shutil.rmtree(dir_path)
+
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+     print("GPU is available. Using CUDA.")
+ else:
+     device = torch.device("cpu")
+     print("GPU is not available. Using CPU.")
+
+ # load the processor from local path
+ local_path = "./model/Molmo-7B-D-0924"
+ #print("Loading processor from local path...")
+ processor = AutoProcessor.from_pretrained(
+     local_path,
+     local_files_only=True,
+     trust_remote_code=True,
+     torch_dtype='auto',
+     device_map='auto'
+ )
+ #print("Processor loaded.")
+
+ print("Loading model from local path...")
+ model = AutoModelForCausalLM.from_pretrained(
+     local_path,
+     trust_remote_code=True,
+     torch_dtype='auto',
+     device_map='auto',
+ )
+ #print("Model loaded.")
+
+ generation_config = GenerationConfig(max_new_tokens=300, stop_strings="<|endoftext|>")
+ bits_and_bytes_config = BitsAndBytesConfig()
+
+ # load the model in bf16 to reduce VRAM needed
+ model.to(dtype=torch.bfloat16)
+ #print("Model loaded in bf16")
+
+ def unzip_images(zip_file):
+     # Create a unique directory for extracted images inside the "images" directory
+     session_dir = os.path.join("images", str(uuid.uuid4()))
+     os.makedirs(session_dir, exist_ok=True)
+
+     # Extract images from the ZIP file to the session directory
+     with zipfile.ZipFile(zip_file, 'r') as zip_ref:
+         for file_info in zip_ref.infolist():
+             if not file_info.is_dir() and not file_info.filename.startswith("__MACOSX") and not file_info.filename.startswith("."):
+                 zip_ref.extract(file_info, session_dir)
+
+     # Get the list of image paths
+     image_paths = [os.path.join(session_dir, filename) for filename in os.listdir(session_dir) if filename.lower().endswith(('.jpg', '.jpeg', '.png'))]
+
+     # Read the image data as PIL Image objects for previews
+     image_data = []
+     for image_path in image_paths:
+         image = Image.open(image_path)
+         image.thumbnail((128, 128))  # Resize the image to a maximum size of 128x128 pixels
+         image_data.append(image)
+
+     # Return the list of image paths and resized image data for previews
+     return image_paths, image_data
+
+ def generate_caption(image_path, processor, model, generation_config, bits_and_bytes_config):
+     # generate a caption and return it
+     caption = f"Caption for {image_path}"
+
+     print("Processing ", image_path)
+
+     image = Image.open(image_path)
+     # process the image and text
+     inputs = processor.process(
+         images=[image],
+         text="Describe what you see in vivid detail, without line breaks. Include information about the pose of characters, their facial expression, their height, body type, weight, the position of their limbs, and the direction of their gaze, the color of their eyes, hair, and skin. If you know a person or place name, provide it. If you know the name of an artist who may have created what you see, provide that. Do not provide opinions or value judgements. Limit your response to 276 words to avoid your description getting cut off.",
+     )
+
+     # move inputs to the correct device and make a batch of size 1
+     inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
+     inputs["images"] = inputs["images"].to(torch.bfloat16)
+
+     # generate output; maximum 500 new tokens; stop generation when <|endoftext|> is generated
+     with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+         output = model.generate_from_batch(
+             inputs,
+             GenerationConfig(max_new_tokens=500, stop_strings="<|endoftext|>"),
+             tokenizer=processor.tokenizer,
+         )
+
+     # only get generated tokens; decode them to text
+     generated_tokens = output[0, inputs["input_ids"].size(1):]
+     generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+     # return the generated text
+     return generated_text
+
+ def process_images(image_paths, image_data):
+     captions = []
+     session_dir = os.path.dirname(image_paths[0])
+
+     for image_path in image_paths:
+         filename = os.path.basename(image_path)  # get the filename from the path
+         if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
+             # Process the image using the loaded model
+             # Use the loaded model to generate the caption
+             caption = generate_caption(image_path, processor, model, generation_config, bits_and_bytes_config)
+             captions.append(caption)
+
+             # Save the caption to a text file
+             with open(os.path.join(session_dir, f"{os.path.splitext(filename)[0]}.txt"), 'w') as f:
+                 f.write(caption)
+
+     # Create a ZIP file containing the caption text files
+     zip_filename = f"{session_dir}.zip"
+     with zipfile.ZipFile(zip_filename, 'w') as zip_ref:
+         for filename in os.listdir(session_dir):
+             if filename.lower().endswith('.txt'):
+                 zip_ref.write(os.path.join(session_dir, filename), filename)
+
+     # Delete the session directory and its contents
+     for filename in os.listdir(session_dir):
+         os.remove(os.path.join(session_dir, filename))
+     os.rmdir(session_dir)
+
+     return captions, zip_filename, image_paths
+
+ def format_captioned_image(image, caption):
+     buffered = BytesIO()
+     image.save(buffered, format="JPEG")
+     encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+     return f"<img src='data:image/jpeg;base64,{encoded_image}' style='width: 128px; height: 128px; object-fit: cover; margin-right: 8px;' /><span>{caption}</span>"
+
+ def process_images_and_update_gallery(zip_file):
+     image_paths, image_data = unzip_images(zip_file)
+     captions, zip_filename, image_paths = process_images(image_paths, image_data)
+     image_captions = [format_captioned_image(img, caption) for img, caption in zip(image_data, captions)]
+     return gr.Markdown("\n".join(image_captions)), zip_filename
+
+ def main():
+     # Register the cleanup function to be called on program exit
+     atexit.register(cleanup_temp_files)
+
+     with gr.Blocks(css="""
+         .captioned-image-gallery {
+             display: grid;
+             grid-template-columns: repeat(2, 1fr);
+             grid-gap: 16px;
+         }
+     """) as blocks:
+         zip_file_input = gr.File(label="Upload ZIP file containing images")
+         image_gallery = gr.Markdown(label="Image Previews")
+         submit_button = gr.Button("Submit")
+         zip_download_button = gr.Button("Download Caption ZIP", visible=False)
+         zip_filename = gr.State("")
+
+         zip_file_input.upload(
+             lambda zip_file: "\n".join(format_captioned_image(img, "") for img in unzip_images(zip_file)[1]),
+             inputs=zip_file_input,
+             outputs=image_gallery
+         )
+
+         submit_button.click(
+             process_images_and_update_gallery,
+             inputs=[zip_file_input],
+             outputs=[image_gallery, zip_filename]
+         )
+
+         zip_filename.change(
+             lambda zip_filename: gr.update(visible=True),
+             inputs=zip_filename,
+             outputs=zip_download_button
+         )
+
+         zip_download_button.click(
+             lambda zip_filename: (gr.update(value=zip_filename), gr.update(visible=True), cleanup_temp_files()),
+             inputs=zip_filename,
+             outputs=[zip_file_input, zip_download_button]
+         )
+
+     blocks.launch(server_name='0.0.0.0')
+
+ if __name__ == "__main__":
+     main()
+
+
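The note at the top of main.py refers to the CUDA_VISIBLE_DEVICES environment variable. It can be set in the shell before launching, or pinned at the very top of the script before torch initializes CUDA; a minimal sketch (the indices 1,2,3 are only an example of excluding GPU 0):

```python
# minimal sketch: hide GPUs of the unwanted architecture before torch initializes CUDA
import os
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1,2,3")  # example indices; keep GPUs of one generation

import torch
print(torch.cuda.device_count())  # should report only the devices left visible
```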
model/Molmo-7B-D-0924 ADDED
@@ -0,0 +1 @@
+ Subproject commit 90426556d5eb7c123eb4368dd1768e8e77f624af
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ gradio
+ uuid
+ bitsandbytes
+ accelerate
+ transformers
+ torch
+ torchvision
+ Pillow
+ requests
+ einops
+ flash-attn
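flash-attn typically compiles against the locally installed torch and CUDA toolchain, so it is worth confirming the heavier dependencies actually resolved after `pip3 install -r requirements.txt`; a quick sanity-check sketch:

```python
# quick sanity check (sketch): confirm torch sees a GPU and the flash-attn wheel built
import torch
import flash_attn  # noqa: F401 - raises ImportError if the build or install failed

assert torch.cuda.is_available(), "Molmo-7B in bf16 effectively needs a CUDA GPU"
print(torch.cuda.get_device_name(0))
```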
test-images.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a84383fcb27d0be0006744b76c97e77d2a45e852d4f17ae29eff2f8346b4923
+ size 3069789