Commit fc250c3
VikramSingh178 committed
1 Parent(s): d2a2d86
Update image captioning script to use Salesforce/blip-image-captioning-large model
scripts/__pycache__/config.cpython-310.pyc (CHANGED)
Binary files a/scripts/__pycache__/config.cpython-310.pyc and b/scripts/__pycache__/config.cpython-310.pyc differ

scripts/__pycache__/logger.cpython-310.pyc (ADDED)
Binary file (919 Bytes).
scripts/config.py (CHANGED)

@@ -4,6 +4,8 @@ VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
 DATASET_NAME= "hahminlew/kream-product-blip-captions"
 PROJECT_NAME = "Product Photography"
 PRODUCTS_10k_DATASET = "amaye15/Products-10k"
+CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-large"
+
 
 
 class Config:
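The new CAPTIONING_MODEL_NAME constant is consumed by scripts/products10k_captions.py below, where the same checkpoint name is used for both the BLIP processor and the captioning model. A minimal sketch of that loading step, assuming the standard transformers from_pretrained API (the variable names here are illustrative, not part of the commit):

    # Sketch only: how the new config constant is typically consumed.
    import torch
    from transformers import BlipProcessor, BlipForConditionalGeneration

    from config import CAPTIONING_MODEL_NAME  # "Salesforce/blip-image-captioning-large"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One checkpoint name serves both the image preprocessing and the model weights.
    processor = BlipProcessor.from_pretrained(CAPTIONING_MODEL_NAME)
    model = BlipForConditionalGeneration.from_pretrained(CAPTIONING_MODEL_NAME).to(device)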
scripts/products10k_captions.py (CHANGED)

@@ -1,15 +1,52 @@
 from datasets import load_dataset
-from config import PRODUCTS_10k_DATASET
-from transformers import BlipProcessor, BlipForConditionalGeneration
+from config import (PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME)
+from transformers import (BlipProcessor, BlipForConditionalGeneration)
 from tqdm import tqdm
 import torch
 
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-dataset = load_dataset(PRODUCTS_10k_DATASET)
 
 
+class ImageCaptioner:
+
+    def __init__(self, dataset: str, processor: str, model: str):
+        self.dataset = load_dataset(dataset)
+        self.processor = BlipProcessor.from_pretrained(processor)
+        self.model = BlipForConditionalGeneration.from_pretrained(model).to(device)
+
+    def process_dataset(self):
+        self.dataset = self.dataset.rename_column(original_column_name='pixel_values', new_column_name='image')
+        self.dataset = self.dataset.remove_columns(column_names=['label'])
+        return self.dataset
+
+    def generate_captions(self):
+        self.dataset = self.process_dataset()
+        self.dataset['image'] = [image.convert("RGB") for image in self.dataset["image"]]
+        print(self.dataset['image'][0])
+        for image in tqdm(self.dataset['image']):
+            inputs = self.processor(image, return_tensors="pt").to(device)
+            out = self.model(**inputs)
+
+
+ic = ImageCaptioner(dataset=PRODUCTS_10k_DATASET, processor=CAPTIONING_MODEL_NAME, model=CAPTIONING_MODEL_NAME)
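As committed, generate_captions stops at a bare forward pass (out = self.model(**inputs)); it never decodes token ids into text, never stores a caption, and the trailing ImageCaptioner(...) line constructs the object without calling the method. A hedged sketch of how the loop is usually finished with the standard BLIP generate/decode calls from transformers, reusing the class from the diff above; the helper name, the "train" split, and the captions list are illustrative assumptions, not part of the commit:

    # Sketch only: decoding and collecting captions with the standard BLIP API.
    from tqdm import tqdm

    def caption_products(captioner: "ImageCaptioner") -> list[str]:
        dataset = captioner.process_dataset()          # renames 'pixel_values' -> 'image', drops 'label'
        captions = []                                  # assumed accumulator; the committed code stores nothing
        for example in tqdm(dataset["train"]):         # split name is an assumption
            image = example["image"].convert("RGB")
            inputs = captioner.processor(image, return_tensors="pt").to(captioner.model.device)
            out = captioner.model.generate(**inputs)   # generate token ids instead of a bare forward pass
            captions.append(captioner.processor.decode(out[0], skip_special_tokens=True))
        return captions

    # Example usage with the object the commit already constructs:
    # captions = caption_products(ic)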