|
|
|
|
|
|
|
|
|
import gradio as gr |
|
import os |
|
import uuid |
|
import zipfile |
|
import torch |
|
from PIL import Image |
|
import requests |
|
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig |
|
from io import BytesIO |
|
import base64 |
|
import atexit |
|
import shutil |
|
|
|
|
|
def cleanup_temp_files():
    """Delete every sub-directory found directly under ``images``.

    Plain files that sit directly inside ``images`` are left in place, and
    the ``images`` directory itself is kept. Nothing happens when ``images``
    does not exist.
    """
    root = "images"
    if not os.path.exists(root):
        return

    for entry in os.listdir(root):
        candidate = os.path.join(root, entry)
        if not os.path.isdir(candidate):
            # Only directories are removed; skip regular files.
            continue
        shutil.rmtree(candidate)
|
|
|
# Pick the compute device once at import time and report the choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU is available. Using CUDA.")
else:
    print("GPU is not available. Using CPU.")
|
|
|
|
|
# Directory expected to hold a local copy of the Molmo-7B-D-0924 checkpoint.
local_path = "./model/Molmo-7B-D-0924"

# NOTE: trust_remote_code=True lets transformers run the Python code that
# ships with the checkpoint, so the directory contents must be trusted.
processor = AutoProcessor.from_pretrained(
    local_path,
    local_files_only=True,
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto'
)

print("Loading model from local path...")
model = AutoModelForCausalLM.from_pretrained(
    local_path,
    # Match the processor load above: never fall back to the network.
    local_files_only=True,
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto',
)

# Module-level generation settings; process_images() hands this object to
# generate_caption().
generation_config = GenerationConfig(max_new_tokens=300, stop_strings="<|endoftext|>")
# Default-constructed quantization config. It is not passed to
# from_pretrained() above; it is only forwarded to generate_caption().
bits_and_bytes_config = BitsAndBytesConfig()

# Cast the loaded weights to bfloat16.
model.to(dtype=torch.bfloat16)
|
|
|
|
|
# File extensions (lower-case) that are treated as images.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _unique_path(directory, filename):
    """Return a path in *directory* for *filename* that does not exist yet."""
    stem, extension = os.path.splitext(filename)
    candidate = os.path.join(directory, filename)
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{counter}{extension}")
        counter += 1
    return candidate


def unzip_images(zip_file):
    """Extract the images of an uploaded ZIP into a fresh session directory.

    A new ``images/<uuid4>`` directory is created on every call. Only members
    whose name ends in ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) are
    extracted. They are written flat into the session directory under their
    base name, so images stored in sub-folders of the archive are found too
    and the session directory never contains sub-directories. Members under
    ``__MACOSX`` or under any path component starting with ``.`` are skipped.

    Args:
        zip_file: Anything ``zipfile.ZipFile`` accepts (a path or a file-like
            object).

    Returns:
        A ``(image_paths, image_data)`` tuple of equal-length lists:
        the extracted file paths (sorted by file name) and the matching
        thumbnails, each at most 128x128 pixels. Files that cannot be opened
        as images are reported on stdout and left out of both lists.
    """
    session_dir = os.path.join("images", str(uuid.uuid4()))
    os.makedirs(session_dir, exist_ok=True)

    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        for file_info in zip_ref.infolist():
            if file_info.is_dir():
                continue
            # Member names always use "/" as separator inside a ZIP archive.
            parts = file_info.filename.split("/")
            if any(part.startswith(("__MACOSX", ".")) for part in parts):
                continue
            filename = os.path.basename(file_info.filename)
            if not filename.lower().endswith(_IMAGE_EXTENSIONS):
                continue
            # Writing to session_dir/<basename> keeps every file inside the
            # session directory regardless of the path stored in the archive.
            target = _unique_path(session_dir, filename)
            with zip_ref.open(file_info) as source, open(target, 'wb') as sink:
                shutil.copyfileobj(source, sink)

    image_paths = []
    image_data = []
    for filename in sorted(os.listdir(session_dir)):
        image_path = os.path.join(session_dir, filename)
        try:
            # The context manager closes the file handle; copy() keeps the
            # already-decoded thumbnail usable afterwards.
            with Image.open(image_path) as image:
                image.thumbnail((128, 128))
                thumbnail = image.copy()
        except OSError as error:
            # Pillow's "cannot identify image" error is an OSError subclass.
            print(f"Skipping unreadable image {image_path}: {error}")
            continue
        image_paths.append(image_path)
        image_data.append(thumbnail)

    return image_paths, image_data
|
|
|
# Instruction sent to the model together with every image.
_CAPTION_PROMPT = (
    "Describe what you see in vivid detail, without line breaks. "
    "Include information about the pose of characters, their facial "
    "expression, their height, body type, weight, the position of their "
    "limbs, and the direction of their gaze, the color of their eyes, hair, "
    "and skin. If you know a person or place name, provide it. If you know "
    "the name of an artist who may have created what you see, provide that. "
    "Do not provide opinions or value judgements. Limit your response to 276 "
    "words to avoid your description getting cut off."
)


def generate_caption(image_path, processor, model, generation_config, bits_and_bytes_config):
    """Generate a text description for a single image file.

    Args:
        image_path: Path of the image to describe.
        processor: Object providing ``process(images=..., text=...)`` and a
            ``tokenizer`` attribute.
        model: Object providing ``device`` and ``generate_from_batch``.
        generation_config: Accepted for interface compatibility but not
            used; generation always runs with a fixed 500-new-token
            configuration (see below).
        bits_and_bytes_config: Accepted for interface compatibility; unused.

    Returns:
        The decoded text of the newly generated tokens, with special tokens
        removed.
    """
    print("Processing ", image_path)

    # Load the pixels as RGB and release the file handle right away. PNG
    # files are often RGBA or palette based, which the model does not expect.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    inputs = processor.process(
        images=[image],
        text=_CAPTION_PROMPT,
    )

    # Move every tensor to the model's device and add a leading batch
    # dimension of size 1.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
    inputs["images"] = inputs["images"].to(torch.bfloat16)

    # Autocast on the device type the model actually lives on instead of
    # assuming CUDA, so the CPU fallback also runs under bfloat16 autocast.
    device_type = "cuda" if model.device.type == "cuda" else "cpu"
    with torch.autocast(device_type=device_type, enabled=True, dtype=torch.bfloat16):
        output = model.generate_from_batch(
            inputs,
            # NOTE: deliberately independent of the ``generation_config``
            # argument; callers pass a smaller token budget.
            GenerationConfig(max_new_tokens=500, stop_strings="<|endoftext|>"),
            tokenizer=processor.tokenizer,
        )

    # Drop the prompt tokens so that only the newly generated part is decoded.
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return generated_text
|
|
|
def process_images(image_paths, image_data):
    """Caption every image and bundle the captions into a ZIP archive.

    For each path ending in ``.jpg``, ``.jpeg`` or ``.png`` a caption is
    generated and written to ``<image stem>.txt`` (UTF-8) next to the image.
    The caption files are then stored in ``<session_dir>.zip`` and the
    session directory (the directory of the first image) is deleted, even
    when captioning fails part-way.

    Args:
        image_paths: Paths of the images to caption; all are expected to live
            in the same directory.
        image_data: Unused; kept so existing callers keep working.

    Returns:
        ``(captions, zip_filename, image_paths)``. For an empty
        ``image_paths`` nothing is done and ``zip_filename`` is ``""``.
    """
    if not image_paths:
        # No first element to derive the session directory from.
        return [], "", image_paths

    captions = []
    # Archive name -> path on disk. A dict avoids duplicate archive entries
    # when two images share a stem (e.g. ``a.jpg`` and ``a.png``).
    caption_files = {}
    session_dir = os.path.dirname(image_paths[0])
    zip_filename = f"{session_dir}.zip"

    try:
        for image_path in image_paths:
            filename = os.path.basename(image_path)
            if not filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue

            caption = generate_caption(image_path, processor, model, generation_config, bits_and_bytes_config)
            captions.append(caption)

            caption_name = f"{os.path.splitext(filename)[0]}.txt"
            caption_path = os.path.join(session_dir, caption_name)
            # Explicit UTF-8: captions may contain characters that the
            # platform default encoding cannot represent.
            with open(caption_path, 'w', encoding='utf-8') as f:
                f.write(caption)
            caption_files[caption_name] = caption_path

        # Only captions written above go into the archive, not arbitrary
        # ``.txt`` files that happen to be in the session directory.
        with zipfile.ZipFile(zip_filename, 'w') as zip_ref:
            for caption_name, caption_path in caption_files.items():
                zip_ref.write(caption_path, caption_name)
    finally:
        # rmtree also copes with sub-directories, which a per-file
        # os.remove() followed by os.rmdir() does not.
        shutil.rmtree(session_dir)

    return captions, zip_filename, image_paths
|
|
|
def format_captioned_image(image, caption):
    """Return an HTML snippet showing *image* inline next to *caption*.

    Args:
        image: Image object providing ``mode``, ``convert`` and ``save``
            (a Pillow image in this file).
        caption: Text shown beside the image; it is HTML-escaped.

    Returns:
        An ``<img>`` tag with the image embedded as a base64 JPEG data URI,
        followed by a ``<span>`` holding the caption.
    """
    # Function-scope import keeps this block self-contained.
    import html

    # JPEG has no alpha channel or palette; saving e.g. an RGBA or P mode
    # image (common for PNG files) as JPEG raises an error.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # The caption is model output and may contain "<", ">" or "&". It is
    # element content, not an attribute value, so quotes stay as they are.
    safe_caption = html.escape(caption, quote=False)

    return f"<img src='data:image/jpeg;base64,{encoded_image}' style='width: 128px; height: 128px; object-fit: cover; margin-right: 8px;' /><span>{safe_caption}</span>"
|
|
|
def process_images_and_update_gallery(zip_file):
    """Caption all images of an uploaded ZIP and build the gallery markup.

    Args:
        zip_file: Value of the upload component; ``None`` when nothing has
            been uploaded.

    Returns:
        A ``(gallery, zip_filename)`` tuple: a ``gr.Markdown`` holding one
        thumbnail-plus-caption snippet per image, and the path of the ZIP
        archive containing the caption files.

    Raises:
        gr.Error: If no file was uploaded or the archive has no images.
    """
    if zip_file is None:
        raise gr.Error("Please upload a ZIP file containing images first.")

    image_paths, image_data = unzip_images(zip_file)
    if not image_paths:
        # Fail with a readable message instead of an error further down.
        raise gr.Error("No .jpg, .jpeg or .png images were found in the uploaded ZIP file.")

    captions, zip_filename, image_paths = process_images(image_paths, image_data)
    image_captions = [format_captioned_image(img, caption) for img, caption in zip(image_data, captions)]
    return gr.Markdown("\n".join(image_captions)), zip_filename
|
|
|
def main():
    """Build the Gradio interface, wire up its events and start the server.

    The UI consists of a ZIP upload, a Markdown area for thumbnails and
    captions, a submit button and an initially hidden download button.
    """
    # Remove leftover session directories when the process exits.
    atexit.register(cleanup_temp_files)

    gallery_css = """
    .captioned-image-gallery {
        display: grid;
        grid-template-columns: repeat(2, 1fr);
        grid-gap: 16px;
    }
    """

    def preview_upload(zip_file):
        """Render uncaptioned thumbnails for a freshly uploaded archive."""
        _, thumbnails = unzip_images(zip_file)
        return "\n".join(format_captioned_image(img, "") for img in thumbnails)

    def show_download_button(zip_filename):
        """Make the download button visible once a caption ZIP exists."""
        return gr.update(visible=True)

    def publish_caption_zip(zip_filename):
        """Clean up session directories and expose the caption ZIP.

        Returns exactly one update per declared output: the file component
        receives the caption ZIP path and the button stays visible.
        """
        cleanup_temp_files()
        return gr.update(value=zip_filename), gr.update(visible=True)

    with gr.Blocks(css=gallery_css) as blocks:
        zip_file_input = gr.File(label="Upload ZIP file containing images")
        image_gallery = gr.Markdown(label="Image Previews")
        submit_button = gr.Button("Submit")
        zip_download_button = gr.Button("Download Caption ZIP", visible=False)
        # Holds the path of the most recently generated caption ZIP.
        zip_filename = gr.State("")

        zip_file_input.upload(
            preview_upload,
            inputs=zip_file_input,
            outputs=image_gallery
        )

        submit_button.click(
            process_images_and_update_gallery,
            inputs=[zip_file_input],
            outputs=[image_gallery, zip_filename]
        )

        zip_filename.change(
            show_download_button,
            inputs=zip_filename,
            outputs=zip_download_button
        )

        zip_download_button.click(
            publish_caption_zip,
            inputs=zip_filename,
            outputs=[zip_file_input, zip_download_button]
        )

    # Listen on all network interfaces.
    blocks.launch(server_name='0.0.0.0')
|
|
|
# Start the application only when this file is run as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|
|