Commit: ca2a4e0
Parent: ad35dba

refactor: Update image captioning script to use Salesforce/blip-image-captioning-large model
product_diffusion_api/__pycache__/endpoints.cpython-310.pyc  CHANGED
Binary files a/product_diffusion_api/__pycache__/endpoints.cpython-310.pyc and b/product_diffusion_api/__pycache__/endpoints.cpython-310.pyc differ

product_diffusion_api/endpoints.py  CHANGED

@@ -5,6 +5,7 @@ from routers import painting
 
 
 
+
 app = FastAPI(openapi_url='/api/v1/product-diffusion/openapi.json',docs_url='/api/v1/product_diffusion/docs')
 app.add_middleware(
     CORSMiddleware,
@@ -38,4 +39,5 @@ async def root():
 
 @app.get("/health")
 def check_health():
-    return {"status": "ok"}
+    return {"status": "ok"}
+
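The /health handler touched above is easy to smoke-test in-process with FastAPI's TestClient. A minimal sketch, assuming the module is importable as product_diffusion_api.endpoints (import path inferred from the file name, not confirmed by the diff) and that httpx is installed:

# Hypothetical in-process check of the /health route shown in the diff above.
from fastapi.testclient import TestClient

from product_diffusion_api.endpoints import app  # import path assumed from the file layout

client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}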

product_diffusion_api/routers/__pycache__/painting.cpython-310.pyc  ADDED
Binary file (243 Bytes).

scripts/products10k_captions.py  CHANGED

@@ -1,9 +1,10 @@
+import torch
 from datasets import load_dataset, Dataset
 from transformers import BlipProcessor, BlipForConditionalGeneration
 from tqdm import tqdm
-from config import PRODUCTS_10k_DATASET,CAPTIONING_MODEL_NAME
-import torch
 
+# Assuming PRODUCTS_10k_DATASET and CAPTIONING_MODEL_NAME are defined in config.py
+from config import PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -25,12 +26,13 @@ class ImageCaptioner:
 
     Methods:
         process_dataset: Preprocesses the dataset.
-        ...
+        generate_caption: Generates a caption for a single image.
+        generate_captions: Generates captions for all images in the dataset.
     """
 
     def __init__(self, dataset: str, processor: str, model: str, prompt: str = "Product photo of"):
-        self.dataset = load_dataset(dataset, split="...
+        self.dataset = load_dataset(dataset, split="test")
+        self.dataset = self.dataset.select(range(10000))  # For demonstration purposes
         self.processor = BlipProcessor.from_pretrained(processor)
         self.model = BlipForConditionalGeneration.from_pretrained(model).to(device)
         self.prompt = prompt
@@ -41,34 +43,48 @@ class ImageCaptioner:
 
         Returns:
             The preprocessed dataset.
-
         """
-
+        # Check if 'image' column exists, otherwise use 'pixel_values' if it exists
+        image_column = "image" if "image" in self.dataset.column_names else "pixel_values"
+        self.dataset = self.dataset.rename_column(image_column, "image")
+
         if "label" in self.dataset.column_names:
             self.dataset = self.dataset.remove_columns(["label"])
+
+        # Add an empty 'text' column for captions if it doesn't exist
+        if "text" not in self.dataset.column_names:
+            new_column = [""] * len(self.dataset)
+            self.dataset = self.dataset.add_column("text", new_column)
+
         return self.dataset
 
-    def ...
+    def generate_caption(self, example):
         """
-        Generates ...
+        Generates a caption for a single image.
 
-        ...
+        Args:
+            example (dict): A dictionary containing the image data.
 
+        Returns:
+            dict: The dictionary with the generated caption.
         """
-        ...
-            blip_caption = self.processor.decode(outputs[0], skip_special_tokens=True)
-            self.dataset[idx]["text"] = blip_caption
-
+        image = example["image"].convert("RGB")
+        inputs = self.processor(images=image, return_tensors="pt").to(device)
+        prompt_inputs = self.processor(text=[self.prompt], return_tensors="pt").to(device)
+        outputs = self.model.generate(**inputs, **prompt_inputs)
+        blip_caption = self.processor.decode(outputs[0], skip_special_tokens=True)
+        example["text"] = blip_caption
+        return example
 
-
+    def generate_captions(self):
+        """
+        Generates captions for all images in the dataset.
 
+        Returns:
+            Dataset: The dataset with generated captions.
+        """
+        self.dataset = self.process_dataset()
+        self.dataset = self.dataset.map(self.generate_caption, batched=False)
         return self.dataset
 
 # Initialize ImageCaptioner
@@ -76,10 +92,11 @@ ic = ImageCaptioner(
     dataset=PRODUCTS_10k_DATASET,
     processor=CAPTIONING_MODEL_NAME,
     model=CAPTIONING_MODEL_NAME,
-    prompt='...
+    prompt='Commercial photography of'
 )
 
 # Generate captions for the dataset
 products10k_dataset = ic.generate_captions()
-
-
+
+# Save the dataset to the hub
+products10k_dataset.push_to_hub("VikramSingh178/Products-10k-BLIP-captions")
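For reference, the prompt-conditioned captioning that generate_caption performs can be reproduced standalone with the Salesforce/blip-image-captioning-large checkpoint named in the commit message. A minimal sketch: 'product.jpg' is a placeholder path, and passing images and text to the processor in a single call is equivalent to the two separate processor calls in the diff above:

import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

image = Image.open("product.jpg").convert("RGB")  # placeholder image path
# Conditional captioning: the text prompt seeds the decoder, so the
# generated caption begins with this phrase.
inputs = processor(images=image, text="Commercial photography of", return_tensors="pt").to(device)
out = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(out[0], skip_special_tokens=True))

Once pushed, the captioned data can be pulled back with load_dataset("VikramSingh178/Products-10k-BLIP-captions", split="train"); "train" is the default split name push_to_hub assigns when a bare Dataset is uploaded.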