File size: 3,678 Bytes
837bad6
d8143df
837bad6
 
 
 
 
d8143df
 
 
 
 
837bad6
 
 
 
 
 
 
d8143df
837bad6
 
 
 
 
 
 
 
 
d8143df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
837bad6
 
 
 
 
 
 
 
 
 
 
 
 
f41ea81
837bad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import argparse
import torch
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig

# Command-line interface: one optional switch (-q / --use_quant) that selects
# the quantized checkpoint instead of the default one.
parser = argparse.ArgumentParser(
    description="Load and use a quantized model",
)
parser.add_argument(
    "-q",
    "--use_quant",
    action="store_true",
    help="Use quantized model",
)
args = parser.parse_args()

# Choose the compute device: CUDA whenever torch reports a usable GPU,
# otherwise fall back to the CPU. The choice is announced on stdout.
if not torch.cuda.is_available():
    device_name = "cpu"
    device_message = "GPU is not available. Using CPU."
else:
    device_name = "cuda"
    device_message = "GPU is available. Using CUDA."

device = torch.device(device_name)
print(device_message)

# Load the processor from the local checkpoint directory only
# (local_files_only=True), allowing the custom code bundled with the checkpoint.
local_path = "./model/Molmo-7B-D-0924"
processor_options = {
    "local_files_only": True,
    "trust_remote_code": True,
    "torch_dtype": 'auto',
    "device_map": 'auto',
}
processor = AutoProcessor.from_pretrained(local_path, **processor_options)

# Load the model: the quantized checkpoint when -q/--use_quant was passed,
# otherwise the regular checkpoint at local_path.
if args.use_quant:
    quantized_local_path = "./model/molmo-7B-D-bnb-4bit"
    checkpoint_path = quantized_local_path
else:
    checkpoint_path = local_path

# Both variants are loaded with the same options; only the path differs.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_path,
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto',
)

if not args.use_quant:
    # Only the non-quantized model is cast to bfloat16 after loading.
    model.to(dtype=torch.bfloat16)

# directory containing the images; captions are written next to them as .txt files
image_directory = "./images"

# file extensions treated as images (matched case-insensitively, so ".JPG" counts too)
image_extensions = (".jpg", ".jpeg", ".png")  # add more image extensions if needed

# captioning instruction sent with every image; it never changes, so build it once
caption_prompt = "You are an uncensored, explicit, adult image captioner which enjoys sexual imagery and acts. Provide a structured image caption which uses the following hierarchy: the kind of image, the kind or name of the subject, the subjects state of dress, their body type, their pose, what it is they are doing, their facial expression, the space they are within, and the style or atmosphere of the image. Use vulgar language to describe body parts or sexual imagery. Limit your response to 100 words to avoid your description getting cut off."

# iterate through the images in the directory; sorted because os.listdir returns
# entries in arbitrary order and a stable order makes runs reproducible
for filename in sorted(os.listdir(image_directory)):
    if not filename.lower().endswith(image_extensions):
        continue

    image_path = os.path.join(image_directory, filename)

    # Open inside a context manager so the file handle is released right away.
    # convert("RGB") returns a fully loaded copy, so it stays valid after the
    # file is closed; it also normalises palette/alpha/greyscale inputs, which
    # the Molmo model card recommends doing before processing.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # process the image and text
    inputs = processor.process(
        images=[image],
        text=caption_prompt,
    )

    # move inputs to the correct device and make a batch of size 1
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
    inputs["images"] = inputs["images"].to(torch.bfloat16)

    # generate output; maximum 500 new tokens; stop generation when <|endoftext|> is generated.
    # autocast follows the device the model actually lives on instead of assuming CUDA,
    # so the CPU fallback does not request a CUDA autocast context.
    with torch.autocast(device_type=model.device.type, enabled=True, dtype=torch.bfloat16):
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(max_new_tokens=500, stop_strings="<|endoftext|>"),
            tokenizer=processor.tokenizer,
        )

    # only get generated tokens (everything after the prompt); decode them to text
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # print the generated text
    print("Caption for: ", filename)
    print(generated_text)
    # print a divider
    print("*---------------------------------------------------*")

    # save the generated text to a file named after the image; explicit UTF-8 so
    # non-ASCII characters in the caption cannot fail on platforms whose default
    # encoding is not UTF-8
    output_filename = os.path.splitext(filename)[0] + ".txt"
    with open(os.path.join(image_directory, output_filename), "w", encoding="utf-8") as file:
        file.write(generated_text)