fancyfeast commited on
Commit
9d2cfda
·
1 Parent(s): fd8406c

Initial commit

Browse files
9em124t2-499968/clip_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d7b0548d12fa649370896982c2af9d03d43285b782bd47639c96e6e0b29473c
3
+ size 1713067838
9em124t2-499968/config.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_project: joy-caption-1
2
+ device_batch_size: 2
3
+ batch_size: 256
4
+ learning_rate: 0.0002
5
+ warmup_samples: 18000
6
+ max_samples: 500000
7
+ save_every: 50000
8
+ test_every: 50000
9
+ use_amp: true
10
+ grad_scaler: true
11
+ lr_scheduler_type: cosine
12
+ min_lr_ratio: 0.0
13
+ allow_tf32: true
14
+ seed: 69
15
+ num_workers: 8
16
+ optimizer_type: adamw
17
+ adam_beta1: 0.9
18
+ adam_beta2: 0.999
19
+ adam_eps: 1.0e-08
20
+ adam_weight_decay: 0.0
21
+ clip_grad_norm: 1.0
22
+ dataset: fancyfeast/joy-captioning-20240917a
23
+ clip_model: google/siglip-so400m-patch14-384
24
+ text_model: meta-llama/Meta-Llama-3.1-8B
25
+ resume: null
26
+ gradient_checkpointing: false
27
+ test_size: 2048
28
+ grad_scaler_init: 65536.0
29
+ max_caption_length: 257
30
+ num_image_tokens: 32
31
+ adapter_type: mlp
32
+ text_model_dtype: bfloat16
33
+ pre_test: false
34
+ train_image_model: true
35
+ image_model_lr: null
36
+ train_lora: true
37
+ lora_r: 64
38
+ lora_alpha: 16
39
+ lora_dropout: 0.1
9em124t2-499968/image_adapter.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e53c3bf8df745a3c19ae3c70dbf9bf23cfdc8f3fdb937000a4eafd2a36914661
3
+ size 86067714
9em124t2-499968/text_model/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Meta-Llama-3.1-8B
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.12.0
9em124t2-499968/text_model/adapter_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "meta-llama/Meta-Llama-3.1-8B",
5
+ "bias": "none",
6
+ "fan_in_fan_out": false,
7
+ "inference_mode": true,
8
+ "init_lora_weights": true,
9
+ "layer_replication": null,
10
+ "layers_pattern": null,
11
+ "layers_to_transform": null,
12
+ "loftq_config": {},
13
+ "lora_alpha": 16,
14
+ "lora_dropout": 0.1,
15
+ "megatron_config": null,
16
+ "megatron_core": "megatron.core",
17
+ "modules_to_save": null,
18
+ "peft_type": "LORA",
19
+ "r": 64,
20
+ "rank_pattern": {},
21
+ "revision": null,
22
+ "target_modules": [
23
+ "q_proj",
24
+ "v_proj"
25
+ ],
26
+ "task_type": "CAUSAL_LM",
27
+ "use_dora": false,
28
+ "use_rslora": false
29
+ }
9em124t2-499968/text_model/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b48221de174ab0db7b46b4833118c5c0a4c2bf0b51b77b4cc4ab04651bd06cca
3
+ size 109069176
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ from huggingface_hub import InferenceClient
4
+ from torch import nn
5
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
6
+ from pathlib import Path
7
+ import torch
8
+ import torch.amp.autocast_mode
9
+ from PIL import Image
10
+ import os
11
+ import torchvision.transforms.functional as TVF
12
+
13
+
14
+ CLIP_PATH = "google/siglip-so400m-patch14-384"
15
+ MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
16
+ CHECKPOINT_PATH = Path("9em124t2-499968")
17
+ TITLE = "<h1><center>JoyCaption Alpha One (2024-09-20a)</center></h1>"
18
+ CAPTION_TYPE_MAP = {
19
+ ("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
20
+ ("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
21
+ ("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
22
+ ("descriptive", "informal", False, False): ["Write a descriptive caption for this image in a casual tone."],
23
+ ("descriptive", "informal", False, True): ["Write a descriptive caption for this image in a casual tone within {word_count} words."],
24
+ ("descriptive", "informal", True, False): ["Write a {length} descriptive caption for this image in a casual tone."],
25
+
26
+ ("training_prompt", "formal", False, False): ["Write a stable diffusion prompt for this image."],
27
+ ("training_prompt", "formal", False, True): ["Write a stable diffusion prompt for this image within {word_count} words."],
28
+ ("training_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt for this image."],
29
+
30
+ ("rng-tags", "formal", False, False): ["Write a list of Booru tags for this image."],
31
+ ("rng-tags", "formal", False, True): ["Write a list of Booru tags for this image within {word_count} words."],
32
+ ("rng-tags", "formal", True, False): ["Write a {length} list of Booru tags for this image."],
33
+ }
34
+
35
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
36
+
37
+
38
+ class ImageAdapter(nn.Module):
39
+ def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
40
+ super().__init__()
41
+ self.deep_extract = deep_extract
42
+
43
+ if self.deep_extract:
44
+ input_features = input_features * 5
45
+
46
+ self.linear1 = nn.Linear(input_features, output_features)
47
+ self.activation = nn.GELU()
48
+ self.linear2 = nn.Linear(output_features, output_features)
49
+ self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
50
+ self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
51
+
52
+ # Mode token
53
+ #self.mode_token = nn.Embedding(n_modes, output_features)
54
+ #self.mode_token.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
55
+
56
+ # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
57
+ self.other_tokens = nn.Embedding(3, output_features)
58
+ self.other_tokens.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
59
+
60
+ def forward(self, vision_outputs: torch.Tensor):
61
+ if self.deep_extract:
62
+ x = torch.concat((
63
+ vision_outputs[-2],
64
+ vision_outputs[3],
65
+ vision_outputs[7],
66
+ vision_outputs[13],
67
+ vision_outputs[20],
68
+ ), dim=-1)
69
+ assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}" # batch, tokens, features
70
+ assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
71
+ else:
72
+ x = vision_outputs[-2]
73
+
74
+ x = self.ln1(x)
75
+
76
+ if self.pos_emb is not None:
77
+ assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
78
+ x = x + self.pos_emb
79
+
80
+ x = self.linear1(x)
81
+ x = self.activation(x)
82
+ x = self.linear2(x)
83
+
84
+ # Mode token
85
+ #mode_token = self.mode_token(mode)
86
+ #assert mode_token.shape == (x.shape[0], mode_token.shape[1], x.shape[2]), f"Expected {(x.shape[0], 1, x.shape[2])}, got {mode_token.shape}"
87
+ #x = torch.cat((x, mode_token), dim=1)
88
+
89
+ # <|image_start|>, IMAGE, <|image_end|>
90
+ other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
91
+ assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
92
+ x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
93
+
94
+ return x
95
+
96
+ def get_eot_embedding(self):
97
+ return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
98
+
99
+
100
+
101
+ # Load CLIP
102
+ print("Loading CLIP")
103
+ clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
104
+ clip_model = AutoModel.from_pretrained(CLIP_PATH)
105
+ clip_model = clip_model.vision_model
106
+
107
+ if (CHECKPOINT_PATH / "clip_model.pt").exists():
108
+ print("Loading VLM's custom vision model")
109
+ checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
110
+ checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
111
+ clip_model.load_state_dict(checkpoint)
112
+ del checkpoint
113
+
114
+ clip_model.eval()
115
+ clip_model.requires_grad_(False)
116
+ clip_model.to("cuda")
117
+
118
+
119
+ # Tokenizer
120
+ print("Loading tokenizer")
121
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
122
+ assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
123
+
124
+ # LLM
125
+ print("Loading LLM")
126
+ if (CHECKPOINT_PATH / "text_model").exists:
127
+ print("Loading VLM's custom text model")
128
+ text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
129
+ else:
130
+ text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
131
+
132
+ text_model.eval()
133
+
134
+ # Image Adapter
135
+ print("Loading image adapter")
136
+ image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
137
+ image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
138
+ image_adapter.eval()
139
+ image_adapter.to("cuda")
140
+
141
+
142
+ @spaces.GPU()
143
+ @torch.no_grad()
144
+ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int) -> str:
145
+ torch.cuda.empty_cache()
146
+
147
+ length = None if caption_length == "any" else caption_length
148
+ prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
149
+ if prompt_key not in CAPTION_TYPE_MAP:
150
+ raise ValueError(f"Invalid caption type: {prompt_key}")
151
+
152
+ prompt_str = CAPTION_TYPE_MAP[prompt_key][0]
153
+
154
+ # Preprocess image
155
+ #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
156
+ image = input_image.resize((384, 384), Image.LANCZOS)
157
+ pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
158
+ pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
159
+ pixel_values = pixel_values.to('cuda')
160
+
161
+ # Tokenize the prompt
162
+ prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
163
+
164
+ # Embed image
165
+ with torch.amp.autocast_mode.autocast('cuda', enabled=True):
166
+ vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
167
+ image_features = vision_outputs.hidden_states
168
+ embedded_images = image_adapter(image_features)
169
+ embedded_images = embedded_images.to('cuda')
170
+
171
+ # Embed prompt
172
+ prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
173
+ assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
174
+ embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
175
+ eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
176
+
177
+ # Construct prompts
178
+ inputs_embeds = torch.cat([
179
+ embedded_bos.expand(embedded_images.shape[0], -1, -1),
180
+ embedded_images.to(dtype=embedded_bos.dtype),
181
+ prompt_embeds.expand(embedded_images.shape[0], -1, -1),
182
+ eot_embed.expand(embedded_images.shape[0], -1, -1),
183
+ ], dim=1)
184
+
185
+ input_ids = torch.cat([
186
+ torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
187
+ torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
188
+ prompt,
189
+ torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
190
+ ], dim=1).to('cuda')
191
+ attention_mask = torch.ones_like(input_ids)
192
+
193
+ #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
194
+ #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
195
+ generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
196
+
197
+ # Trim off the prompt
198
+ generate_ids = generate_ids[:, input_ids.shape[1]:]
199
+ if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
200
+ generate_ids = generate_ids[:, :-1]
201
+
202
+ caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
203
+
204
+ return caption.strip()
205
+
206
+
207
+ with gr.Blocks() as demo:
208
+ gr.HTML(TITLE)
209
+
210
+ with gr.Row():
211
+ with gr.Column():
212
+ input_image = gr.Image(type="pil", label="Input Image")
213
+
214
+ caption_type = gr.Dropdown(
215
+ choices=["descriptive", "training_prompt", "rng-tags"],
216
+ label="Caption Type",
217
+ value="descriptive",
218
+ )
219
+
220
+ caption_tone = gr.Dropdown(
221
+ choices=["formal", "informal"],
222
+ label="Caption Tone",
223
+ value="formal",
224
+ )
225
+
226
+ caption_length = gr.Dropdown(
227
+ choices=["any", "very short", "short", "medium-length", "long", "very long"] +
228
+ [str(i) for i in range(20, 261, 10)],
229
+ label="Caption Length",
230
+ value="any",
231
+ )
232
+
233
+ run_button = gr.Button("Caption")
234
+
235
+ with gr.Column():
236
+ output_caption = gr.Textbox(label="Caption")
237
+
238
+ run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length], outputs=[output_caption])
239
+
240
+
241
+ if __name__ == "__main__":
242
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ huggingface_hub==0.23.4
2
+ accelerate
3
+ torch
4
+ transformers==4.44.0
5
+ sentencepiece
6
+ peft==0.12.0