VikramSingh178 committed on
Commit fc250c3 • 1 Parent(s): d2a2d86

Update image captioning script to use Salesforce/blip-image-captioning-large model

scripts/__pycache__/config.cpython-310.pyc CHANGED
Binary files a/scripts/__pycache__/config.cpython-310.pyc and b/scripts/__pycache__/config.cpython-310.pyc differ
 
scripts/__pycache__/logger.cpython-310.pyc ADDED
Binary file (919 Bytes) added.
 
scripts/config.py CHANGED
@@ -4,6 +4,8 @@ VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
 DATASET_NAME= "hahminlew/kream-product-blip-captions"
 PROJECT_NAME = "Product Photography"
 PRODUCTS_10k_DATASET = "amaye15/Products-10k"
+CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-large"
+
 
 
 class Config:
scripts/products10k_captions.py CHANGED
@@ -1,15 +1,52 @@
 from datasets import load_dataset
-from config import PRODUCTS_10k_DATASET
-from transformers import BlipProcessor, BlipForConditionalGeneration
+from config import (PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME)
+from transformers import (BlipProcessor, BlipForConditionalGeneration)
 from tqdm import tqdm
 import torch
 
 
+
+
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-dataset = load_dataset(PRODUCTS_10k_DATASET)
 
 
-def image_captioning(processor , )
+
+class ImageCaptioner:
+
+    def __init__(self, dataset: str, processor: str, model: str):
+        self.dataset = load_dataset(dataset)
+        self.processor = BlipProcessor.from_pretrained(processor)
+        self.model = BlipForConditionalGeneration.from_pretrained(model).to(device)
+
+    def process_dataset(self):
+        self.dataset = self.dataset.rename_column(original_column_name='pixel_values', new_column_name='image')
+        self.dataset = self.dataset.remove_columns(column_names=['label'])
+        return self.dataset
+
+    def generate_captions(self):
+        self.dataset = self.process_dataset()
+        self.dataset['image'] = [image.convert("RGB") for image in self.dataset["image"]]
+        print(self.dataset['image'][0])
+        for image in tqdm(self.dataset['image']):
+            inputs = self.processor(image, return_tensors="pt").to(device)
+            out = self.model(**inputs)
+
+
+ic = ImageCaptioner(dataset=PRODUCTS_10k_DATASET, processor=CAPTIONING_MODEL_NAME, model=CAPTIONING_MODEL_NAME)
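
Note: the committed generate_captions loop calls the model's forward pass (self.model(**inputs)), which returns logits rather than caption text. A minimal sketch of how captions are typically produced and decoded with this checkpoint is shown below; it is not part of the commit, and the caption_image helper and max_new_tokens value are illustrative assumptions.

import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the captioning checkpoint named by CAPTIONING_MODEL_NAME in config.py.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

def caption_image(image: Image.Image) -> str:
    # Preprocess the PIL image into the tensors BLIP expects.
    inputs = processor(image.convert("RGB"), return_tensors="pt").to(device)
    # generate() returns caption token ids; decode them back to text.
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=30)
    return processor.decode(out[0], skip_special_tokens=True)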