John6666 committed on
Commit
b46152e
•
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
9em124t2-499968/clip_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d7b0548d12fa649370896982c2af9d03d43285b782bd47639c96e6e0b29473c
+ size 1713067838
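The three lines above are a Git LFS pointer, not the checkpoint itself; the ~1.7 GB `clip_model.pt` is stored in LFS and materialized on an LFS-aware clone or via the Hub API. A minimal sketch of fetching just this file with `huggingface_hub` (the repo id is taken from the clone URL in the README below and is otherwise an assumption):

```python
# Sketch: download only the vision checkpoint from the Hub, then inspect it.
# Assumes the repo id from the README's clone URL; adjust if the repo moves.
import torch
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="John6666/joy-caption-alpha-one-cli-mod",
    filename="9em124t2-499968/clip_model.pt",
)
state_dict = torch.load(path, map_location="cpu", weights_only=False)
print(len(state_dict), "tensors")  # keys carry an "_orig_mod.module." prefix that app.py strips
```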
9em124t2-499968/config.yaml ADDED
@@ -0,0 +1,39 @@
+ wandb_project: joy-caption-1
+ device_batch_size: 2
+ batch_size: 256
+ learning_rate: 0.0002
+ warmup_samples: 18000
+ max_samples: 500000
+ save_every: 50000
+ test_every: 50000
+ use_amp: true
+ grad_scaler: true
+ lr_scheduler_type: cosine
+ min_lr_ratio: 0.0
+ allow_tf32: true
+ seed: 69
+ num_workers: 8
+ optimizer_type: adamw
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ adam_eps: 1.0e-08
+ adam_weight_decay: 0.0
+ clip_grad_norm: 1.0
+ dataset: fancyfeast/joy-captioning-20240917a
+ clip_model: google/siglip-so400m-patch14-384
+ text_model: meta-llama/Meta-Llama-3.1-8B
+ resume: null
+ gradient_checkpointing: false
+ test_size: 2048
+ grad_scaler_init: 65536.0
+ max_caption_length: 257
+ num_image_tokens: 32
+ adapter_type: mlp
+ text_model_dtype: bfloat16
+ pre_test: false
+ train_image_model: true
+ image_model_lr: null
+ train_lora: true
+ lora_r: 64
+ lora_alpha: 16
+ lora_dropout: 0.1
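This config records the training run behind the checkpoint: SigLIP so400m as the vision tower, Llama 3.1 8B as the text model, an MLP adapter with 32 image tokens, and a LoRA with r=64, alpha=16 on the text side. A minimal sketch of reading it back (uses PyYAML, which is not in requirements.txt and is assumed to be installed):

```python
# Sketch: load the training config and print a few key fields.
import yaml  # PyYAML, assumed installed

with open("9em124t2-499968/config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["clip_model"])                                      # google/siglip-so400m-patch14-384
print(cfg["text_model"])                                      # meta-llama/Meta-Llama-3.1-8B
print(cfg["num_image_tokens"], cfg["adapter_type"])           # 32 mlp
print(cfg["lora_r"], cfg["lora_alpha"], cfg["lora_dropout"])  # 64 16 0.1
```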
9em124t2-499968/image_adapter.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e53c3bf8df745a3c19ae3c70dbf9bf23cfdc8f3fdb937000a4eafd2a36914661
+ size 86067714
9em124t2-499968/text_model/README.md ADDED
@@ -0,0 +1,202 @@
+ ---
+ base_model: meta-llama/Meta-Llama-3.1-8B
+ library_name: peft
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
+ ### Framework versions
+
+ - PEFT 0.12.0
9em124t2-499968/text_model/adapter_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": null,
+   "base_model_name_or_path": "meta-llama/Meta-Llama-3.1-8B",
+   "bias": "none",
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 16,
+   "lora_dropout": 0.1,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 64,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "q_proj",
+     "v_proj"
+   ],
+   "task_type": "CAUSAL_LM",
+   "use_dora": false,
+   "use_rslora": false
+ }
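This is a standard PEFT LoRA config: rank 64, alpha 16, dropout 0.1, applied to the `q_proj` and `v_proj` projections of Meta-Llama-3.1-8B. A minimal sketch of attaching the adapter with `PeftModel` (assumes access to the gated base model and enough memory; app.py below takes a different route via `add_adapter` on a quantized model):

```python
# Sketch: attach the LoRA in 9em124t2-499968/text_model to the base Llama model.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Reads adapter_config.json and adapter_model.safetensors from the directory.
model = PeftModel.from_pretrained(base, "9em124t2-499968/text_model")
model.eval()
```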
9em124t2-499968/text_model/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b48221de174ab0db7b46b4833118c5c0a4c2bf0b51b77b4cc4ab04651bd06cca
+ size 109069176
README.md ADDED
@@ -0,0 +1,79 @@
+ ---
+ license: mit
+ language:
+ - en
+ ---
+ # Image Captioning App
+
+ This is a mod of [Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha) and [fancyfeast/joy-caption-alpha-one](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-one). Thanks to [dominic1021](https://huggingface.co/dominic1021).
+
+ # Notice: I plan to contribute this back to Wi-zz once the code is cleaned up.
+
+ ## Overview
+
+ This application generates descriptive captions for images. It processes single images or entire directories, feeding CLIP-style (SigLIP) image features through a custom adapter into an LLM to produce accurate, contextual captions, and it supports NSFW captioning in natural language. This is an extension of the original author's work focused on improving performance; the original repo is at https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-one.
+
+ ## Features
+
+ - Single image and batch processing
+ - Multiple directory support
+ - Custom output directory
+ - Adjustable batch size
+ - Progress tracking
+
+ ## Usage
+
+ | Command | Description |
+ |---------|-------------|
+ | `python app.py image.jpg` | Process a single image |
+ | `python app.py /path/to/directory` | Process all images in a directory |
+ | `python app.py /path/to/dir1 /path/to/dir2` | Process multiple directories |
+ | `python app.py /path/to/dir --output /path/to/output` | Specify output directory |
+ | `python app.py /path/to/dir --bs 8` | Set batch size (default: 4) |
+
+ ## Technical Details
+
+ - **Models**: CLIP (vision), LLM (language), custom ImageAdapter (sketched just after this README)
+ - **Optimization**: CUDA-enabled GPU support
+ - **Error Handling**: Skips problematic images in batch processing
+
+ ## Requirements
+
+ - Python 3.x
+ - PyTorch
+ - Transformers library
+ - PEFT library
+ - CUDA-capable GPU (recommended)
+
+ ## Installation
+
+ Windows
+
+ ```bash
+ git clone https://huggingface.co/John6666/joy-caption-alpha-one-cli-mod
+ cd joy-caption-alpha-one-cli-mod
+ python -m venv venv
+ .\venv\Scripts\activate
+ # Change as per https://pytorch.org/get-started/locally/
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+ pip install -r requirements.txt
+ ```
+
+ Linux
+
+ ```bash
+ git clone https://huggingface.co/John6666/joy-caption-alpha-one-cli-mod
+ cd joy-caption-alpha-one-cli-mod
+ python3 -m venv venv
+ source venv/bin/activate
+ pip3 install torch torchvision torchaudio
+ pip3 install -r requirements.txt
+ ```
+
+ ## Contributing
+
+ Contributions are welcome! Please feel free to submit a Pull Request.
+
+ ## License
+
+ This project is licensed under the [MIT License](LICENSE).
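As noted under Technical Details above, captioning runs in three stages: SigLIP encodes the image, the `ImageAdapter` projects the vision features into the LLM's embedding space, and the LLM generates the caption from the image tokens plus a text prompt. A simplified sketch of that flow (names follow `app.py` further down in this commit; BOS/EOT token handling, batching, and device placement are omitted):

```python
# Simplified single-image flow; see app.py in this commit for the full version.
import torch

def caption_one(image, clip_processor, clip_model, image_adapter, tokenizer, text_model):
    pixels = clip_processor(images=[image], return_tensors="pt").pixel_values         # 1. preprocess
    hidden = clip_model(pixel_values=pixels, output_hidden_states=True).hidden_states
    image_embeds = image_adapter(hidden)                                               # 2. project to LLM space
    prompt_ids = tokenizer.encode("Write a descriptive caption for this image.",
                                  return_tensors="pt", add_special_tokens=False)
    prompt_embeds = text_model.model.embed_tokens(prompt_ids)
    inputs_embeds = torch.cat([image_embeds, prompt_embeds.to(image_embeds.dtype)], dim=1)  # 3. image + prompt
    out = text_model.generate(inputs_embeds=inputs_embeds, max_new_tokens=300, do_sample=True)
    return tokenizer.decode(out[0], skip_special_tokens=True)                          # 4. caption
```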
app.py ADDED
@@ -0,0 +1,381 @@
+ import torch
+ import torch.amp.autocast_mode
+ import os
+ import sys
+ import logging
+ import warnings
+ import argparse
+ from PIL import Image
+ from pathlib import Path
+ from tqdm import tqdm
+ from torch import nn
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
+ from typing import List, Union
+ import torchvision.transforms.functional as TVF
+ from peft import PeftConfig
+ import gc
+
+ # Constants
+ BASE_DIR = Path(__file__).resolve().parent  # Define the base directory
+ CLIP_PATH = "google/siglip-so400m-patch14-384"
+ DEFAULT_MODEL_PATH = "unsloth/Meta-Llama-3.1-8B-bnb-4bit"
+ #DEFAULT_MODEL_PATH = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2" # Works better but full weight.
+ CHECKPOINT_PATH = BASE_DIR / Path("9em124t2-499968")
+ LORA_PATH = CHECKPOINT_PATH / "text_model"
+ CAPTION_TYPE_MAP = {
+     ("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
+     ("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
+     ("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
+     ("descriptive", "informal", False, False): ["Write a descriptive caption for this image in a casual tone."],
+     ("descriptive", "informal", False, True): ["Write a descriptive caption for this image in a casual tone within {word_count} words."],
+     ("descriptive", "informal", True, False): ["Write a {length} descriptive caption for this image in a casual tone."],
+
+     ("training_prompt", "formal", False, False): ["Write a stable diffusion prompt for this image."],
+     ("training_prompt", "formal", False, True): ["Write a stable diffusion prompt for this image within {word_count} words."],
+     ("training_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt for this image."],
+
+     ("rng-tags", "formal", False, False): ["Write a list of Booru tags for this image."],
+     ("rng-tags", "formal", False, True): ["Write a list of Booru tags for this image within {word_count} words."],
+     ("rng-tags", "formal", True, False): ["Write a {length} list of Booru tags for this image."],
+ }
+ IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
+
+ # Global Variables
+ IS_NF4 = True
+ MODEL_PATH = DEFAULT_MODEL_PATH
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Running on {device}")
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+ logging.getLogger("transformers").setLevel(logging.ERROR)
+
+ class ImageAdapter(nn.Module):
+     def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
+         super().__init__()
+         self.deep_extract = deep_extract
+
+         if self.deep_extract:
+             input_features = input_features * 5
+
+         self.linear1 = nn.Linear(input_features, output_features)
+         self.activation = nn.GELU()
+         self.linear2 = nn.Linear(output_features, output_features)
+         self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
+         self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
+
+         # Mode token
+         #self.mode_token = nn.Embedding(n_modes, output_features)
+         #self.mode_token.weight.data.normal_(mean=0.0, std=0.02)   # Matches HF's implementation of llama3
+
+         # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
+         self.other_tokens = nn.Embedding(3, output_features)
+         self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)   # Matches HF's implementation of llama3
+
+     def forward(self, vision_outputs: torch.Tensor):
+         if self.deep_extract:
+             x = torch.concat((
+                 vision_outputs[-2],
+                 vision_outputs[3],
+                 vision_outputs[7],
+                 vision_outputs[13],
+                 vision_outputs[20],
+             ), dim=-1)
+             assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"  # batch, tokens, features
+             assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
+         else:
+             x = vision_outputs[-2]
+
+         x = self.ln1(x)
+
+         if self.pos_emb is not None:
+             assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
+             x = x + self.pos_emb
+
+         x = self.linear1(x)
+         x = self.activation(x)
+         x = self.linear2(x)
+
+         # Mode token
+         #mode_token = self.mode_token(mode)
+         #assert mode_token.shape == (x.shape[0], mode_token.shape[1], x.shape[2]), f"Expected {(x.shape[0], 1, x.shape[2])}, got {mode_token.shape}"
+         #x = torch.cat((x, mode_token), dim=1)
+
+         # <|image_start|>, IMAGE, <|image_end|>
+         other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
+         assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
+         x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
+
+         return x
+
+     def get_eot_embedding(self):
+         return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
+
+ def load_models():
+     global MODEL_PATH, IS_NF4
+     try:
+         if IS_NF4:
+             from transformers import BitsAndBytesConfig
+             nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
+                                             bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
+             print("Loading in NF4")
+             print("Loading CLIP 📎")
+             clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
+             clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model
+             if (CHECKPOINT_PATH / "clip_model.pt").exists():
+                 print("Loading VLM's custom vision model 📎")
+                 checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu', weights_only=False)
+                 checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
+                 clip_model.load_state_dict(checkpoint)
+                 del checkpoint
+             clip_model.eval().requires_grad_(False).to(device)
+
+             print("Loading tokenizer 🪙")
+             tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
+             assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"
+
+             print(f"Loading LLM: {MODEL_PATH} 🤖")
+             text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
+
+             if LORA_PATH.exists():
+                 print("Loading VLM's custom text model 🤖")
+                 peft_config = PeftConfig.from_pretrained(LORA_PATH, device_map=device, quantization_config=nf4_config)
+                 text_model.add_adapter(peft_config)
+                 text_model.enable_adapters()
+
+             print("Loading image adapter 🖼️")
+             image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
+             image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=False))
+             image_adapter.eval().to(device)
+         else:
+             print("Loading in bfloat16")
+             print("Loading CLIP 📎")
+             clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
+             clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model
+             if (CHECKPOINT_PATH / "clip_model.pt").exists():
+                 print("Loading VLM's custom vision model 📎")
+                 checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu', weights_only=False)
+                 checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
+                 clip_model.load_state_dict(checkpoint)
+                 del checkpoint
+             clip_model.eval().requires_grad_(False).to(device)
+
+             print("Loading tokenizer 🪙")
+             tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
+             assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"
+
+             print(f"Loading LLM: {MODEL_PATH} 🤖")
+             text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map=device, torch_dtype=torch.bfloat16).eval()  # device_map=auto may cause LoRA error
+
+             if LORA_PATH.exists():
+                 print("Loading VLM's custom text model 🤖")
+                 peft_config = PeftConfig.from_pretrained(LORA_PATH, device_map=device)
+                 text_model.add_adapter(peft_config)
+                 text_model.enable_adapters()
+
+             print("Loading image adapter 🖼️")
+             image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
+             image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=False))
+     except Exception as e:
+         print(f"Error loading models: {e}")
+         sys.exit(1)
+     finally:
+         torch.cuda.empty_cache()
+         gc.collect()
+     return clip_processor, clip_model, tokenizer, text_model, image_adapter
+
+ @torch.inference_mode()
+ def stream_chat(input_images: List[Image.Image], caption_type: str, caption_tone: str, caption_length: Union[str, int],
+                 max_new_tokens: int, top_p: float, temperature: float, batch_size: int, pbar: tqdm, models: tuple) -> List[str]:
+     global MODEL_PATH
+     clip_processor, clip_model, tokenizer, text_model, image_adapter = models
+     torch.cuda.empty_cache()
+     all_captions = []
+
+     # 'any' means no length specified
+     length = None if caption_length == "any" else caption_length
+
+     if isinstance(length, str):
+         try:
+             length = int(length)
+         except ValueError:
+             pass
+
+     # 'rng-tags' and 'training_prompt' don't have formal/informal tones
+     if caption_type == "rng-tags" or caption_type == "training_prompt":
+         caption_tone = "formal"
+
+     # Build prompt
+     prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
+     if prompt_key not in CAPTION_TYPE_MAP:
+         raise ValueError(f"Invalid caption type: {prompt_key}")
+
+     prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
+     print(f"Prompt: {prompt_str}")
+
+     for i in range(0, len(input_images), batch_size):
+         batch = input_images[i:i+batch_size]
+         # Preprocess image
+         try:
+             all_images = []
+             for input_image in batch:
+                 image = input_image.resize((384, 384), Image.LANCZOS)
+                 pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+                 pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
+                 all_images.append(TVF.to_pil_image(pixel_values.squeeze()))
+             batch_pixel_values = clip_processor(images=all_images, return_tensors='pt', padding=True).pixel_values.to(device)
+         except ValueError as e:
+             print(f"Error processing image batch: {e}")
+             print("Skipping this batch and continuing...")
+             continue
+
+         # Embed image
+         with torch.amp.autocast_mode.autocast(device, enabled=True):
+             vision_outputs = clip_model(pixel_values=batch_pixel_values, output_hidden_states=True)
+             image_features = vision_outputs.hidden_states
+             embedded_images = image_adapter(image_features).to(device)
+
+         # Tokenize the prompt
+         prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
+
+         # Embed prompt
+         prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
+         assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
+         embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
+         eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
+
+         # Construct prompts
+         inputs_embeds = torch.cat([
+             embedded_bos.expand(embedded_images.shape[0], -1, -1),
+             embedded_images.to(dtype=embedded_bos.dtype),
+             prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+             eot_embed.expand(embedded_images.shape[0], -1, -1),
+         ], dim=1)
+
+         input_ids = torch.cat([
+             torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
+             torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+             prompt,
+             torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
+         ], dim=1).to(device)
+         attention_mask = torch.ones_like(input_ids)
+
+         generate_ids = text_model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, do_sample=True,
+                                            suppress_tokens=None, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature)
+
+         generate_ids = generate_ids[:, input_ids.shape[1]:]
+
+         for ids in generate_ids:
+             caption = tokenizer.decode(ids[:-1] if ids[-1] == tokenizer.eos_token_id else ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+             caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
+             all_captions.append(caption)
+
+         if pbar:
+             pbar.update(len(batch))
+
+     return all_captions
+
+ def process_directory(input_dir: Path, output_dir: Path, caption_type: str, caption_tone: str, caption_length: Union[str, int],
+                       max_new_tokens: int, top_p: float, temperature: float, batch_size: int, models: tuple):
+     output_dir.mkdir(parents=True, exist_ok=True)
+     image_files = [f for f in input_dir.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS]
+     images_to_process = [f for f in image_files if not (output_dir / f"{f.stem}.txt").exists()]
+
+     if not images_to_process:
+         print("No new images to process.")
+         return
+
+     with tqdm(total=len(images_to_process), desc="Processing images", unit="image") as pbar:
+         for i in range(0, len(images_to_process), batch_size):
+             batch_files = images_to_process[i:i+batch_size]
+             batch_images = [Image.open(f).convert('RGB') for f in batch_files]
+
+             captions = stream_chat(batch_images, caption_type, caption_tone, caption_length,
+                                    max_new_tokens, top_p, temperature, batch_size, pbar, models)
+
+             for file, caption in zip(batch_files, captions):
+                 with open(output_dir / f"{file.stem}.txt", 'w', encoding='utf-8') as f:
+                     f.write(caption)
+
+             for img in batch_images:
+                 img.close()
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser(description="Process images and generate captions.")
+     parser.add_argument("input", nargs='+', help="Input image file or directory (or multiple directories)")
+     parser.add_argument("--output", help="Output directory (optional)")
+     parser.add_argument("--bs", type=int, default=4, help="Batch size (default: 4)")
+     parser.add_argument("--type", type=str, default="descriptive", choices=["descriptive", "training_prompt", "rng-tags"],
+                         help='Caption Type (default: "descriptive")')
+     parser.add_argument("--tone", type=str, default="formal", choices=["formal", "informal"],
+                         help='Caption Tone (default: "formal")')
+     parser.add_argument("--len", default="any",
+                         choices=["any", "very short", "short", "medium-length", "long", "very long"] + [str(i) for i in range(20, 261, 10)],
+                         help='Caption Length (default: "any")')
+     parser.add_argument("--model", type=str, default=DEFAULT_MODEL_PATH,
+                         help='Huggingface LLM repo (default: "unsloth/Meta-Llama-3.1-8B-bnb-4bit")')
+     parser.add_argument("--bf16", action="store_true", help="Use bfloat16 (default: NF4)")
+     parser.add_argument("--tokens", type=int, default=300, help="Max tokens (default: 300)")
+     parser.add_argument("--topp", type=float, default=0.9, help="Top-P (default: 0.9)")
+     parser.add_argument("--temp", type=float, default=0.6, help="Temperature (default: 0.6)")
+     return parser.parse_args()
+
+ def is_valid_repo(repo_id):
+     from huggingface_hub import HfApi
+     import re
+     try:
+         if not re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', repo_id): return False
+         api = HfApi()
+         if api.repo_exists(repo_id=repo_id): return True
+         else: return False
+     except Exception as e:
+         print(f"Failed to connect {repo_id}. {e}")
+         return False
+
+ def main():
+     global MODEL_PATH, IS_NF4
+     args = parse_arguments()
+     input_paths = [Path(input_path) for input_path in args.input]
+     batch_size = args.bs
+     caption_type = args.type
+     caption_tone = args.tone
+     caption_length = args.len
+     max_new_tokens = args.tokens
+     top_p = args.topp
+     temperature = args.temp
+     if args.bf16: IS_NF4 = False
+     else: IS_NF4 = True
+     if is_valid_repo(args.model): MODEL_PATH = args.model
+     else: sys.exit(1)
+     models = load_models()
+
+     for input_path in input_paths:
+         if input_path.is_file() and input_path.suffix.lower() in IMAGE_EXTENSIONS:
+             output_path = input_path.with_suffix('.txt')
+             print(f"Processing single image 🎞️: {input_path.name}")
+             with tqdm(total=1, desc="Processing image", unit="image") as pbar:
+                 captions = stream_chat([Image.open(input_path).convert('RGB')], caption_type, caption_tone, caption_length,
+                                        max_new_tokens, top_p, temperature, 1, pbar, models)
+             with open(output_path, 'w', encoding='utf-8') as f:
+                 f.write(captions[0])
+             print(f"Output saved to {output_path}")
+         elif input_path.is_dir():
+             output_path = Path(args.output) if args.output else input_path
+             print(f"Processing directory 📁: {input_path}")
+             print(f"Output directory 📦: {output_path}")
+             print(f"Batch size 🗄️: {batch_size}")
+             process_directory(input_path, output_path, caption_type, caption_tone, caption_length,
+                               max_new_tokens, top_p, temperature, batch_size, models)
+         else:
+             print(f"Invalid input: {input_path}")
+             print("Skipping...")
+
+     if not input_paths:
+         print("Usage:")
+         print("For single image: python app.py [image_file] [--bs batch_size]")
+         print("For directory (same input/output): python app.py [directory] [--bs batch_size]")
+         print("For directory (separate input/output): python app.py [directory] --output [output_directory] [--bs batch_size]")
+         print("For multiple directories: python app.py [directory1] [directory2] ... [--output output_directory] [--bs batch_size]")
+         sys.exit(1)
+
+ if __name__ == "__main__":
+     main()
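Besides the CLI shown in the README, the same functions can be driven from Python. A minimal sketch, assuming this file is saved as `app.py` next to the `9em124t2-499968` checkpoint directory and that `example.jpg` is a hypothetical input image:

```python
# Sketch: call the captioner programmatically instead of via the CLI.
from PIL import Image
from app import load_models, stream_chat  # importing app.py prints the detected device

models = load_models()  # loads CLIP, tokenizer, LLM (+ LoRA if present) and the image adapter

captions = stream_chat(
    [Image.open("example.jpg").convert("RGB")],  # hypothetical input image
    caption_type="descriptive", caption_tone="formal", caption_length="any",
    max_new_tokens=300, top_p=0.9, temperature=0.6,
    batch_size=1, pbar=None, models=models,
)
print(captions[0])
```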
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ huggingface_hub>=0.23.4
+ accelerate
+ torch
+ transformers==4.44.0
+ sentencepiece
+ bitsandbytes
+ Pillow
+ protobuf
+ peft==0.12.0
+ torchvision
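The pins that matter most are `transformers==4.44.0` and `peft==0.12.0`; the other dependencies float. A small optional check (an assumed helper, not part of the app) to confirm the pinned versions are what the environment actually has:

```python
# Sketch: compare installed versions against the two exact pins in requirements.txt.
from importlib.metadata import version

for pkg, pinned in {"transformers": "4.44.0", "peft": "0.12.0"}.items():
    installed = version(pkg)
    status = "OK" if installed == pinned else "MISMATCH"
    print(f"{pkg}: installed {installed}, pinned {pinned} -> {status}")
```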