VikramSingh178 committed on
Commit
ad35dba
1 Parent(s): e9b25d5

refactor: Update image captioning script to use Salesforce/blip-image-captioning-base model

scripts/__pycache__/config.cpython-310.pyc CHANGED
Binary files a/scripts/__pycache__/config.cpython-310.pyc and b/scripts/__pycache__/config.cpython-310.pyc differ
 
scripts/config.py CHANGED
@@ -4,7 +4,7 @@ VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
 DATASET_NAME= "hahminlew/kream-product-blip-captions"
 PROJECT_NAME = "Product Photography"
 PRODUCTS_10k_DATASET = "amaye15/Products-10k"
-CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-large"
+CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
 
 
 
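For context, a minimal sketch (not part of this commit) of how the updated CAPTIONING_MODEL_NAME checkpoint is typically loaded and used for plain, unconditional captioning with transformers; the image path is a placeholder for illustration only.

# Minimal sketch, assuming transformers, torch, and Pillow are installed.
# "example_product.jpg" is a placeholder path, not a file from the repo.
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = BlipProcessor.from_pretrained(CAPTIONING_MODEL_NAME)
model = BlipForConditionalGeneration.from_pretrained(CAPTIONING_MODEL_NAME).to(device)

image = Image.open("example_product.jpg").convert("RGB")  # placeholder image
inputs = processor(images=image, return_tensors="pt").to(device)
output_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(output_ids[0], skip_special_tokens=True))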
scripts/products10k_captions.py CHANGED
@@ -1,9 +1,10 @@
-from datasets import load_dataset
-from config import PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME
+from datasets import load_dataset, Dataset
 from transformers import BlipProcessor, BlipForConditionalGeneration
 from tqdm import tqdm
+from config import PRODUCTS_10k_DATASET,CAPTIONING_MODEL_NAME
 import torch
 
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 class ImageCaptioner:
@@ -14,11 +15,13 @@ class ImageCaptioner:
         dataset (str): The path to the dataset.
         processor (str): The pre-trained processor model to use for image processing.
         model (str): The pre-trained model to use for caption generation.
+        prompt (str): The conditioning prompt to use for caption generation.
 
     Attributes:
         dataset: The loaded dataset.
         processor: The pre-trained processor model.
         model: The pre-trained caption generation model.
+        prompt: The conditioning prompt for generating captions.
 
     Methods:
         process_dataset: Preprocesses the dataset.
@@ -26,10 +29,11 @@ class ImageCaptioner:
 
     """
 
-    def __init__(self, dataset: str, processor: str, model: str):
+    def __init__(self, dataset: str, processor: str, model: str, prompt: str = "Product photo of"):
         self.dataset = load_dataset(dataset, split="train")
         self.processor = BlipProcessor.from_pretrained(processor)
         self.model = BlipForConditionalGeneration.from_pretrained(model).to(device)
+        self.prompt = prompt
 
     def process_dataset(self):
         """
@@ -57,19 +61,25 @@ class ImageCaptioner:
         for idx in tqdm(range(len(self.dataset))):
             image = self.dataset[idx]["image"].convert("RGB")
             inputs = self.processor(images=image, return_tensors="pt").to(device)
-            outputs = self.model.generate(**inputs)
+            prompt_inputs = self.processor(text=[self.prompt], return_tensors="pt").to(device)
+            outputs = self.model.generate(**inputs, **prompt_inputs)
             blip_caption = self.processor.decode(outputs[0], skip_special_tokens=True)
-            self.dataset[idx]["caption"] = blip_caption
-            print(f"Caption for image {idx}: {blip_caption}")
+            self.dataset[idx]["text"] = blip_caption
+
 
-        # Optionally, you can save the dataset with captions to disk
-        # self.dataset.save_to_disk('path_to_save_dataset')
+
 
         return self.dataset
 
+# Initialize ImageCaptioner
 ic = ImageCaptioner(
     dataset=PRODUCTS_10k_DATASET,
     processor=CAPTIONING_MODEL_NAME,
     model=CAPTIONING_MODEL_NAME,
+    prompt='Photography of ' # Adding the conditioning prompt
 )
-ic.generate_captions()
+
+# Generate captions for the dataset
+products10k_dataset = ic.generate_captions()
+new_dataset = Dataset.from_pandas(products10k_dataset.to_pandas()) # Convert to a `datasets` Dataset if necessary
+new_dataset.push_to_hub("VikramSingh178/Products-10k-BLIP-captions")
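For reference, a sketch (an assumption, not part of this commit) of the same prompt-conditioned BLIP captioning written with datasets.Dataset.map, which persists the generated caption as a "text" column; indexing a datasets.Dataset returns a plain dict, so per-row item assignment like the loop above does not modify the stored table. The model name, dataset id, and target repo id mirror the ones used in the commit; the prompt string is illustrative.

# Sketch under stated assumptions: the "image" column decodes to a PIL image,
# and the target Hub repo id matches the one pushed to in the commit.
import torch
from datasets import load_dataset
from transformers import BlipProcessor, BlipForConditionalGeneration

MODEL_NAME = "Salesforce/blip-image-captioning-base"
DATASET_NAME = "amaye15/Products-10k"
PROMPT = "Photography of"  # illustrative conditioning prompt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = BlipProcessor.from_pretrained(MODEL_NAME)
model = BlipForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)

def caption(example):
    image = example["image"].convert("RGB")
    # Passing both the image and the prompt text conditions generation on the prompt.
    inputs = processor(images=image, text=PROMPT, return_tensors="pt").to(device)
    output_ids = model.generate(**inputs, max_new_tokens=30)
    example["text"] = processor.decode(output_ids[0], skip_special_tokens=True)
    return example

dataset = load_dataset(DATASET_NAME, split="train")
dataset = dataset.map(caption)  # map returns a new dataset with the added column
dataset.push_to_hub("VikramSingh178/Products-10k-BLIP-captions")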