DucHaiten committed
Commit 3734881 · verified · 1 Parent(s): 9f86577

Update image_to_caption.py

Files changed (1):
  1. image_to_caption.py (+72 -42)
image_to_caption.py CHANGED
@@ -10,6 +10,7 @@ from transformers import AutoModelForCausalLM, LlamaTokenizer
 import json
 import traceback
 import math
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 torch.set_grad_enabled(False)
 
@@ -125,7 +126,7 @@ def update_and_save_config():
 'temperature': temperature_var.get(),
 'top_k': top_k_var.get(),
 'top_p': float(top_p_value) if top_p_value is not None else None,
-'bit_precision': bit_precision_var.get(),  # merges both precision and bit
+'bit_precision': bit_precision_var.get(),  # save bit_precision
 'thread_count': thread_count_var.get(),
 'batch_size': batch_size_var.get(),
 'prepend_text': prepend_text_var.get(),
@@ -150,7 +151,7 @@ def load_config_from_json():
 top_k_var.set(config_entry.get('top_k', 50))
 top_p_var.set(config_entry.get('top_p', 0.95))
 bit_precision_var.set(config_entry.get('bit_precision', 8))  # load bit_precision
-thread_count_var.set(config_entry.get('thread_count', 4))
+thread_count_var.set(config_entry.get('thread_count', 1))
 batch_size_var.set(config_entry.get('batch_size', 1))
 prepend_text_var.set(config_entry.get('prepend_text', ''))
 append_text_var.set(config_entry.get('append_text', ''))
@@ -290,7 +291,7 @@ def open_image_to_caption():
 temperature_var = tk.DoubleVar(value=1.0)
 top_k_var = tk.IntVar(value=50)
 top_p_var = tk.DoubleVar(value=0.95)
-thread_count_var = tk.IntVar(value=4)
+thread_count_var = tk.IntVar(value=1)
 precision_var = tk.IntVar(value=1)
 batch_size_var = tk.IntVar(value=1)
 prepend_text_var = tk.StringVar()
@@ -482,7 +483,7 @@ def generate_caption(image_path, save_directory, q):
 load_model()
 
 filename = os.path.splitext(os.path.basename(image_path))[0]
-caption_file_path = os.path.join(save_directory, f"{filename}.txt")
+caption_file_path = os.path.join(save_directory, f"{filename}.txt")  # changed caption file name
 
 # Check the user's selections
 if os.path.exists(caption_file_path):
@@ -497,10 +498,21 @@ def generate_caption(image_path, save_directory, q):
 else:
     existing_caption = ""
 
+# Process the image on the CPU before moving it to the GPU
 image = PILImage.open(image_path).convert('RGB')
 if not isinstance(image, PILImage.Image):
     raise ValueError(f"Expected image to be of type PIL.Image.Image, but got {type(image)}")
 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Check whether bit_precision is 4 or 8
+if bit_precision_var.get() in [4, 8]:
+    # do not call `.to()` on the model while it is in 4-bit or 8-bit mode
+    pass
+else:
+    model.to(device)
+
+# Handle dtype and inputs accordingly
 inputs = model.build_conversation_input_ids(
     tokenizer,
     query=prompt_var.get(),
@@ -510,14 +522,14 @@ def generate_caption(image_path, save_directory, q):
 
 # Adjust the dtype based on bit_precision
 if bit_precision_var.get() == 32:
-    image_tensor = inputs['images'][0].to('cuda').to(torch.float32)
+    image_tensor = inputs['images'][0].to(device).to(torch.float32)
 else:
-    image_tensor = inputs['images'][0].to('cuda').to(torch.float16)
+    image_tensor = inputs['images'][0].to(device).to(torch.float16)
 
 inputs = {
-    'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
-    'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
-    'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
+    'input_ids': inputs['input_ids'].unsqueeze(0).to(device),
+    'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(device),
+    'attention_mask': inputs['attention_mask'].unsqueeze(0).to(device),
     'images': [[image_tensor]],
 }
 
@@ -530,7 +542,8 @@ def generate_caption(image_path, save_directory, q):
 "num_beams": precision_var.get()
 }
 
-with torch.no_grad():
+# Use torch.amp.autocast to improve performance on the GPU
+with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16 if bit_precision_var.get() != 32 else torch.float32):
     outputs = model.generate(**inputs, **gen_kwargs)
     outputs = outputs[:, inputs['input_ids'].shape[1]:]
     new_caption = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -541,7 +554,7 @@ def generate_caption(image_path, save_directory, q):
 file.write(final_caption)
 
 q.put(image_path)
-torch.cuda.empty_cache()
+
 except torch.cuda.OutOfMemoryError as e:
     torch.cuda.empty_cache()
     error_message = f"CUDA OutOfMemoryError: {traceback.format_exc()}"
@@ -553,45 +566,55 @@ def generate_caption(image_path, save_directory, q):
 print(error_message)
 q.put(error_message)
 error_messages.append(error_message)
+finally:
+    if stop_processing or bit_precision_var.get() not in [4, 8]:
+        model.to('cpu')
+        torch.cuda.empty_cache()
+
 
 
 def worker(save_directory, num_threads, batch_size):
     try:
         progress.set(0)
-        threads = []
-
         num_batches = math.ceil(len(selected_files) / batch_size)
-        batch_size_per_thread = max(1, batch_size // num_threads)  # number of images each thread handles per batch
-
-        for batch_index in range(num_batches):
-            if stop_processing:
-                break
-
-            start_index = batch_index * batch_size
-            end_index = min(start_index + batch_size, len(selected_files))
-            batch = selected_files[start_index:end_index]
-
-            # Split the images in the batch across the threads
-            for i in range(0, len(batch), batch_size_per_thread):
-                thread_batch = batch[i:i + batch_size_per_thread]
-                thread = threading.Thread(target=generate_captions_for_batch, args=(thread_batch, save_directory, q))
-                threads.append(thread)
-                thread.start()
-
-            # Wait for the threads in the current batch to finish
-            for thread in threads:
-                thread.join()
-            threads.clear()
+        batch_size_per_thread = max(1, batch_size // num_threads)
+
+        def process_batch(thread_batch):
+            generate_captions_for_batch(thread_batch, save_directory, q)
+
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            for batch_index in range(num_batches):
+                if stop_processing:
+                    break
+
+                start_index = batch_index * batch_size
+                end_index = min(start_index + batch_size, len(selected_files))
+                batch = selected_files[start_index:end_index]
+
+                futures = []
+                for i in range(0, len(batch), batch_size_per_thread):
+                    thread_batch = batch[i:i + batch_size_per_thread]
+                    futures.append(executor.submit(process_batch, thread_batch))
+
+                # Wait for the jobs in the current batch to finish
+                for future in as_completed(futures):
+                    try:
+                        future.result()  # surface any error that occurred while processing the batch
+                    except Exception as e:
+                        q.put(f"Error processing batch: {e}")
+                if stop_processing:
+                    break
 
         q.put(None)
     except Exception as e:
         if not stop_processing:
-            q.put(e)
+            q.put(f"Worker encountered an error: {e}")
 
 def generate_captions_for_batch(batch, save_directory, q):
     for image_path in batch:
         generate_caption(image_path, save_directory, q)
 
+
 def update_progress():
     try:
         completed = 0
@@ -758,7 +781,8 @@ def update_image_preview(content_canvas):
 file_label = tk.Label(caption_frame, text=os.path.basename(file_path), font=('Helvetica', 12), wraplength=300, justify="left")
 file_label.grid(row=i*2, column=1, padx=5, pady=5, sticky="nsew")
 
-caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
+filename = os.path.splitext(os.path.basename(file_path))[0]
+caption_file = os.path.join(save_directory, f"{filename}.txt")  # changed caption file name
 if os.path.exists(caption_file):
     with open(caption_file, 'r', encoding='utf-8') as file:
         caption_text = file.read()
@@ -817,7 +841,8 @@ def go_to_page(page_number, content_canvas):
 messagebox.showerror("Invalid Input", "Please enter a valid integer for the page number.")
 
 def save_caption(file_path, caption_text):
-    output_path = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
+    filename = os.path.splitext(os.path.basename(file_path))[0]
+    output_path = os.path.join(save_directory, f"{filename}.txt")  # changed caption file name
     try:
         with open(output_path, 'w', encoding='utf-8') as file:
             file.write(caption_text.strip())
@@ -840,7 +865,8 @@ def search_captions():
 update_image_preview(content_canvas)
 
 def search_score(file_path, search_term):
-    caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
+    filename = os.path.splitext(os.path.basename(file_path))[0]
+    caption_file = os.path.join(save_directory, f"{filename}.txt")  # changed caption file name
     try:
         if os.path.exists(caption_file):
             with open(caption_file, 'r', encoding='utf-8') as file:
@@ -866,7 +892,8 @@ def add_to_captions(position):
 return
 
 for file_path in selected_files:
-    caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
+    filename = os.path.splitext(os.path.basename(file_path))[0]
+    caption_file = os.path.join(save_directory, f"{filename}.txt")  # changed caption file name
     if os.path.exists(caption_file):
         with open(caption_file, 'r+', encoding='utf-8') as file:
             caption_text = file.read()
@@ -889,7 +916,8 @@ def delete_keyword_from_captions():
 return
 
 for file_path in selected_files:
-    caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
+    filename = os.path.splitext(os.path.basename(file_path))[0]
+    caption_file = os.path.join(save_directory, f"{filename}.txt")  # changed caption file name
     if os.path.exists(caption_file):
         with open(caption_file, 'r+', encoding='utf-8') as file:
             caption_text = file.read().lower().replace(keyword, "")
@@ -910,7 +938,8 @@ def delete_images_with_keyword():
 
 files_to_delete = []
 for file_path in selected_files:
-    caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
+    filename = os.path.splitext(os.path.basename(file_path))[0]
+    caption_file = os.path.join(save_directory, f"{filename}.txt")  # changed caption file name
     if os.path.exists(caption_file):
         with open(caption_file, 'r', encoding='utf-8') as file:
             caption_text = file.read().lower()
@@ -920,7 +949,8 @@ def delete_images_with_keyword():
 for file_path in files_to_delete:
     try:
         os.remove(file_path)
-        caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
+        filename = os.path.splitext(os.path.basename(file_path))[0]
+        caption_file = os.path.join(save_directory, f"{filename}.txt")  # changed caption file name
        if os.path.exists(caption_file):
            os.remove(caption_file)
    except Exception as e:
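
A note on the generation hunk: the added comment refers to torch.amp.autocast, while the code calls torch.cuda.amp.autocast. Both exist; on recent PyTorch releases the device-agnostic form takes the device type as its first argument. A minimal, hedged sketch of that form (assuming a PyTorch 2.x install; use_fp16 is a stand-in for bit_precision_var.get() != 32 in the diff, and the generate call is elided):

import torch

use_fp16 = True  # stand-in for bit_precision_var.get() != 32 in the diff
autocast_dtype = torch.float16 if use_fp16 else torch.float32

# Device-agnostic autocast: the device type is passed explicitly.
with torch.no_grad(), torch.amp.autocast("cuda", dtype=autocast_dtype):
    pass  # model.generate(**inputs, **gen_kwargs) would run here under mixed precision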
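
The reworked worker keeps the original per-batch splitting arithmetic, so it can be sanity-checked in isolation. A self-contained sketch under assumed example values (selected_files is a hypothetical stub and the print is illustrative; the real code submits each slice to the ThreadPoolExecutor instead):

import math

selected_files = [f"img_{i}.png" for i in range(10)]  # hypothetical stand-in for the GUI selection
batch_size = 4
num_threads = 2

num_batches = math.ceil(len(selected_files) / batch_size)    # -> 3 batches
batch_size_per_thread = max(1, batch_size // num_threads)    # -> 2 images per submitted job

for batch_index in range(num_batches):
    start_index = batch_index * batch_size
    end_index = min(start_index + batch_size, len(selected_files))
    batch = selected_files[start_index:end_index]
    # Slice each batch into per-thread jobs, the same way the worker does before executor.submit().
    jobs = [batch[i:i + batch_size_per_thread] for i in range(0, len(batch), batch_size_per_thread)]
    print(batch_index, jobs)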