John6666 committed
Commit f902fc6
1 Parent(s): 1650be4

Upload 6 files

Files changed (5)
  1. app.py +11 -5
  2. joycaption.py +82 -26
  3. packages.txt +1 -0
  4. pre-requirements.txt +1 -0
  5. requirements.txt +4 -2
app.py CHANGED
@@ -1,6 +1,6 @@
 import spaces
 import gradio as gr
-from joycaption import stream_chat_mod, get_text_model, change_text_model
+from joycaption import stream_chat_mod, get_text_model, change_text_model, get_repo_gguf
 
 JC_TITLE_MD = "<h1><center>JoyCaption Pre-Alpha Mod</center></h1>"
 JC_DESC_MD = """This space is mod of [fancyfeast/joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha),
@@ -17,9 +17,14 @@ with gr.Blocks() as demo:
     with gr.Group():
         jc_input_image = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"], height=384)
         with gr.Accordion("Advanced", open=False):
-            jc_text_model = gr.Dropdown(label="LLM Model", info="You can enter a huggingface model repo_id to want to use.",
-                                        choices=get_text_model(), value=get_text_model()[0],
-                                        allow_custom_value=True, interactive=True, min_width=320)
+            with gr.Row():
+                jc_text_model = gr.Dropdown(label="LLM Model", info="You can enter a huggingface model repo_id to want to use.",
+                                            choices=get_text_model(), value=get_text_model()[0],
+                                            allow_custom_value=True, interactive=True, min_width=320)
+                jc_gguf = gr.Dropdown(label=f"GGUF Filename", choices=[], value="",
+                                      allow_custom_value=True, min_width=320, visible=False)
+            jc_nf4 = gr.Checkbox(label="Use NF4 quantization", value=True)
+            jc_text_model_button = gr.Button("Load Model", variant="secondary")
             jc_use_inference_client = gr.Checkbox(label="Use Inference Client", value=False, visible=False)
             with gr.Row():
                 jc_tokens = gr.Slider(minimum=1, maximum=4096, value=300, step=1, label="Max tokens")
@@ -32,7 +37,8 @@
     gr.Markdown(JC_DESC_MD, elem_classes="info")
 
     jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_tokens, jc_topk, jc_temperature], outputs=[jc_output_caption])
-    jc_text_model.change(change_text_model, [jc_text_model, jc_use_inference_client], [jc_text_model], show_api=False)
+    jc_text_model_button.click(change_text_model, [jc_text_model, jc_use_inference_client, jc_gguf, jc_nf4], [jc_text_model], show_api=False)
+    #jc_text_model.change(get_repo_gguf, [jc_text_model], [jc_gguf], show_api=False)
     jc_use_inference_client.change(change_text_model, [jc_text_model, jc_use_inference_client], [jc_text_model], show_api=False)
 
 if __name__ == "__main__":
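
Note: the UI changes above follow a common Gradio wiring pattern: a repo_id Dropdown with allow_custom_value, a GGUF-filename Dropdown, an NF4 Checkbox, and a "Load Model" Button whose handler returns gr.update() so the dropdown choices refresh after a successful load. A minimal, self-contained sketch of that pattern (the component names and the dummy load_model handler are illustrative, not the Space's actual code):

```python
import gradio as gr

models = ["Sao10K/Llama-3.1-8B-Stheno-v3.4", "meta-llama/Meta-Llama-3.1-8B"]

def load_model(repo_id: str, gguf_file: str, use_nf4: bool):
    # Stand-in for change_text_model(): load the model here, then refresh the dropdown choices.
    if repo_id and repo_id not in models:
        models.append(repo_id)
    return gr.update(choices=models, value=repo_id)

with gr.Blocks() as demo:
    with gr.Row():
        model = gr.Dropdown(label="LLM Model", choices=models, value=models[0], allow_custom_value=True)
        gguf = gr.Dropdown(label="GGUF Filename", choices=[], value="", allow_custom_value=True)
    nf4 = gr.Checkbox(label="Use NF4 quantization", value=True)
    load_btn = gr.Button("Load Model", variant="secondary")
    load_btn.click(load_model, inputs=[model, gguf, nf4], outputs=[model])

demo.launch()
```

In app.py the real handler is change_text_model, which also receives jc_use_inference_client; the jc_text_model.change(get_repo_gguf, ...) hookup is left commented out in this commit.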
joycaption.py CHANGED
@@ -12,17 +12,16 @@ import gc
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-llm_models = [
-    "Sao10K/Llama-3.1-8B-Stheno-v3.4",
-    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
-    "mergekit-community/L3.1-Boshima-b-FIX",
-    "meta-llama/Meta-Llama-3.1-8B",
-]
-
+llm_models = {
+    "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
+    "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
+    "mergekit-community/L3.1-Boshima-b-FIX": None,
+    "meta-llama/Meta-Llama-3.1-8B": None,
+}
 
 CLIP_PATH = "google/siglip-so400m-patch14-384"
 VLM_PROMPT = "A descriptive caption for this image:\n"
-MODEL_PATH = llm_models[0]
+MODEL_PATH = list(llm_models.keys())[0]
 CHECKPOINT_PATH = Path("wpkklhc6")
 TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
 
@@ -42,21 +41,41 @@ class ImageAdapter(nn.Module):
         x = self.linear2(x)
         return x
 
+# https://huggingface.co/docs/transformers/v4.44.2/gguf
+# https://github.com/city96/ComfyUI-GGUF/issues/7
+# https://github.com/THUDM/ChatGLM-6B/issues/18
+# https://github.com/meta-llama/llama/issues/394
+# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/discussions/109
 # https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
 # https://huggingface.co/google/flan-ul2/discussions/8
-
+# https://huggingface.co/blog/4bit-transformers-bitsandbytes
+tokenizer = None
 text_model_client = None
 text_model = None
 image_adapter = None
-def load_text_model(model_name: str=MODEL_PATH):
+def load_text_model(model_name: str=MODEL_PATH, gguf_file: str | None=None, is_nf4: bool=True):
+    global tokenizer
     global text_model
    global image_adapter
-    global text_model_client
-    global use_inference_client
+    global text_model_client #
+    global use_inference_client #
     try:
+        from transformers import BitsAndBytesConfig
+        nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
+                                        bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
+        print("Loading tokenizer")
+        if gguf_file: tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
+        else: tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=False)
+        assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
         print(f"Loading LLM: {model_name}")
-        if device == "cpu": text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
-        else: text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
+        if gguf_file:
+            if device == "cpu": text_model = AutoModelForCausalLM.from_pretrained(model_name, gguf_file=gguf_file, device_map=device, torch_dtype=torch.bfloat16).eval()
+            elif is_nf4: text_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
+            else: text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
+        else:
+            if device == "cpu": text_model = AutoModelForCausalLM.from_pretrained(model_name, gguf_file=gguf_file, device_map=device, torch_dtype=torch.bfloat16).eval()
+            elif is_nf4: text_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
+            else: text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
         print("Loading image adapter")
         image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size).eval().to("cpu")
         image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
@@ -76,10 +95,6 @@ clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
 clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model.eval().requires_grad_(False).to(device)
 
 # Tokenizer
-print("Loading tokenizer")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
-assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
-
 # LLM
 # Image Adapter
 load_text_model()
@@ -176,11 +191,17 @@ def stream_chat_mod(input_image: Image.Image, max_new_tokens: int=300, top_k: in
     ], dim=1).to(device)
     attention_mask = torch.ones_like(input_ids)
 
+    # https://huggingface.co/docs/transformers/v4.44.2/main_classes/text_generation#transformers.FlaxGenerationMixin.generate
+    # https://github.com/huggingface/transformers/issues/6535
+    # https://zenn.dev/hijikix/articles/8c445f4373fdcc ja
+    # https://github.com/ggerganov/llama.cpp/discussions/7712
     # https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility
     # https://huggingface.co/docs/huggingface_hub/v0.24.6/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
     #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
     generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                        max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, temperature=temperature, suppress_tokens=None)
+
+    print(prompt)
 
     # Trim off the prompt
     generate_ids = generate_ids[:, input_ids.shape[1]:]
@@ -199,8 +220,8 @@ def is_repo_name(s):
 
 def is_repo_exists(repo_id):
     from huggingface_hub import HfApi
-    api = HfApi()
     try:
+        api = HfApi(token=HF_TOKEN)
         if api.repo_exists(repo_id=repo_id): return True
         else: return False
     except Exception as e:
@@ -210,24 +231,59 @@ def is_repo_exists(repo_id):
 
 
 def get_text_model():
-    return llm_models
+    return list(llm_models.keys())
+
+
+def is_gguf_repo(repo_id: str):
+    from huggingface_hub import HfApi
+    try:
+        api = HfApi(token=HF_TOKEN)
+        if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return False
+        files = api.list_repo_files(repo_id=repo_id)
+    except Exception as e:
+        print(f"Error: Failed to get {repo_id}'s info.")
+        print(e)
+        gr.Warning(f"Error: Failed to get {repo_id}'s info.")
+        return False
+    files = [f for f in files if f.endswith(".gguf")]
+    if len(files) == 0: return False
+    else: return True
+
+
+def get_repo_gguf(repo_id: str):
+    from huggingface_hub import HfApi
+    try:
+        api = HfApi(token=HF_TOKEN)
+        if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(value="", choices=[])
+        files = api.list_repo_files(repo_id=repo_id)
+    except Exception as e:
+        print(f"Error: Failed to get {repo_id}'s info.")
+        print(e)
+        gr.Warning(f"Error: Failed to get {repo_id}'s info.")
+        return gr.update(value="", choices=[])
+    files = [f for f in files if f.endswith(".gguf")]
+    if len(files) == 0: return gr.update(value="", choices=[])
+    else: return gr.update(value=files[0], choices=files)
 
 
 @spaces.GPU()
-def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, progress=gr.Progress(track_tqdm=True)):
+def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_file: str | None=None,
+                      is_nf4: bool=True, progress=gr.Progress(track_tqdm=True)):
     global use_inference_client
-    global text_model
     global llm_models
     use_inference_client = use_client
     try:
         if not is_repo_name(model_name) or not is_repo_exists(model_name):
             raise gr.Error(f"Repo doesn't exist: {model_name}")
+        if not gguf_file and is_gguf_repo(model_name):
+            gr.Info(f"Please select a gguf file.")
+            return gr.update(visible=True)
         if use_inference_client:
-            pass
+            pass #
         else:
-            load_text_model(model_name)
-            if model_name not in llm_models: llm_models.append(model_name)
-            return gr.update(visible=True)
+            load_text_model(model_name, gguf_file, is_nf4)
+            if model_name not in llm_models: llm_models[model_name] = gguf_file if gguf_file else None
+            return gr.update(choices=get_text_model())
     except Exception as e:
         raise gr.Error(f"Model load error: {model_name}, {e}")
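
Note: load_text_model() now has three loading paths: a GGUF checkpoint passed through transformers' gguf_file argument, a 4-bit NF4 load via BitsAndBytesConfig/bitsandbytes, and the previous bf16 load; is_gguf_repo()/get_repo_gguf() discover .gguf files with HfApi.list_repo_files. A minimal sketch of the same ideas, assuming a transformers build with GGUF support (hence the git install in requirements.txt) and a CUDA device for the bitsandbytes path; the GGUF repo id at the bottom is only an illustrative assumption:

```python
import torch
from huggingface_hub import HfApi
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def first_gguf_file(repo_id: str) -> str | None:
    """Return the first .gguf file in a repo, or None (same idea as get_repo_gguf)."""
    files = [f for f in HfApi().list_repo_files(repo_id=repo_id) if f.endswith(".gguf")]
    return files[0] if files else None

def load_llm(repo_id: str, gguf_file: str | None = None, is_nf4: bool = True):
    if gguf_file:
        # GGUF checkpoints are dequantized by transformers when loaded this way.
        tok = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)
        model = AutoModelForCausalLM.from_pretrained(repo_id, gguf_file=gguf_file, torch_dtype=torch.bfloat16)
    elif is_nf4:
        # 4-bit NF4 quantization via bitsandbytes (needs a CUDA GPU).
        cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                 bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
        tok = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForCausalLM.from_pretrained(repo_id, quantization_config=cfg, device_map="auto")
    else:
        tok = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="auto")
    return tok, model.eval()

# Example: pick a .gguf file first, then load it (repo id assumed for illustration).
repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
tokenizer, text_model = load_llm(repo, gguf_file=first_gguf_file(repo))
```

Unlike this sketch, the Space keeps tokenizer, text_model and image_adapter as module-level globals so stream_chat_mod() can reuse them between requests.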
 
packages.txt ADDED
@@ -0,0 +1 @@
+git-lfs
pre-requirements.txt ADDED
@@ -0,0 +1 @@
+pip>=23.0.0
requirements.txt CHANGED
@@ -1,8 +1,10 @@
 huggingface_hub
 accelerate
 torch
-transformers==4.43.3
+git+https://github.com/huggingface/transformers
 sentencepiece
 bitsandbytes
 Pillow
-protobuf
+protobuf
+gguf
+numpy<2.0.0
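
Note: requirements.txt now installs transformers from GitHub (presumably for newer GGUF loading support), adds the gguf package that transformers' GGUF loader requires, and pins numpy below 2.0. A quick optional sanity check for the resulting environment (a sketch, not part of the Space):

```python
# Print installed versions of the packages this commit touches.
import importlib.metadata as metadata

for pkg in ("transformers", "gguf", "numpy", "bitsandbytes"):
    try:
        print(f"{pkg}=={metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg} is not installed")
```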