VikramSingh178 committed
Commit ca2a4e0
1 Parent(s): ad35dba

refactor: Update image captioning script to use Salesforce/blip-image-captioning-large model

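For context on the commit message: the checkpoint it names is loaded through the same `BlipProcessor` / `BlipForConditionalGeneration` classes the script below uses. A minimal, self-contained sketch of conditional captioning with Salesforce/blip-image-captioning-large (the image path and prompt prefix are illustrative placeholders, not part of this commit):

```python
# Sketch of conditional captioning with the BLIP-large checkpoint named in the commit message.
# "sample_product.jpg" and the prompt prefix are placeholders for illustration only.
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

image = Image.open("sample_product.jpg").convert("RGB")
# Passing text alongside the image makes BLIP complete the caption from that prefix.
inputs = processor(images=image, text="Commercial photography of", return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(outputs[0], skip_special_tokens=True))
```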
product_diffusion_api/__pycache__/endpoints.cpython-310.pyc CHANGED
Binary files a/product_diffusion_api/__pycache__/endpoints.cpython-310.pyc and b/product_diffusion_api/__pycache__/endpoints.cpython-310.pyc differ
 
product_diffusion_api/endpoints.py CHANGED
@@ -5,6 +5,7 @@ from routers import painting
 
 
 
+
 app = FastAPI(openapi_url='/api/v1/product-diffusion/openapi.json',docs_url='/api/v1/product_diffusion/docs')
 app.add_middleware(
     CORSMiddleware,
@@ -38,4 +39,5 @@ async def root():
 
 @app.get("/health")
 def check_health():
-    return {"status": "ok"}
+    return {"status": "ok"}
+
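The functional content in this hunk is the `/health` route; its body is unchanged apart from whitespace. A quick sanity check, assuming the import path implied by the file layout in this diff, could use FastAPI's test client:

```python
# Sketch: exercising the /health endpoint with FastAPI's TestClient (needs httpx installed).
from fastapi.testclient import TestClient
from product_diffusion_api.endpoints import app  # import path assumed from the file layout in this diff

client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
```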
product_diffusion_api/routers/__pycache__/painting.cpython-310.pyc ADDED
Binary file (243 Bytes).
 
scripts/products10k_captions.py CHANGED
@@ -1,9 +1,10 @@
+import torch
 from datasets import load_dataset, Dataset
 from transformers import BlipProcessor, BlipForConditionalGeneration
 from tqdm import tqdm
-from config import PRODUCTS_10k_DATASET,CAPTIONING_MODEL_NAME
-import torch
 
+# Assuming PRODUCTS_10k_DATASET and CAPTIONING_MODEL_NAME are defined in config.py
+from config import PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -25,12 +26,13 @@ class ImageCaptioner:
 
     Methods:
         process_dataset: Preprocesses the dataset.
-        generate_captions: Generates captions for the images in the dataset.
-
+        generate_caption: Generates a caption for a single image.
+        generate_captions: Generates captions for all images in the dataset.
     """
 
     def __init__(self, dataset: str, processor: str, model: str, prompt: str = "Product photo of"):
-        self.dataset = load_dataset(dataset, split="train")
+        self.dataset = load_dataset(dataset, split="test")
+        self.dataset = self.dataset.select(range(10000))  # For demonstration purposes
         self.processor = BlipProcessor.from_pretrained(processor)
         self.model = BlipForConditionalGeneration.from_pretrained(model).to(device)
         self.prompt = prompt
@@ -41,34 +43,48 @@ class ImageCaptioner:
 
         Returns:
             The preprocessed dataset.
-
         """
-        self.dataset = self.dataset.rename_column("pixel_values", "image")
+        # Check if 'image' column exists, otherwise use 'pixel_values' if it exists
+        image_column = "image" if "image" in self.dataset.column_names else "pixel_values"
+        self.dataset = self.dataset.rename_column(image_column, "image")
+
         if "label" in self.dataset.column_names:
             self.dataset = self.dataset.remove_columns(["label"])
+
+        # Add an empty 'text' column for captions if it doesn't exist
+        if "text" not in self.dataset.column_names:
+            new_column = [""] * len(self.dataset)
+            self.dataset = self.dataset.add_column("text", new_column)
+
         return self.dataset
 
-    def generate_captions(self):
+    def generate_caption(self, example):
         """
-        Generates captions for the images in the dataset.
+        Generates a caption for a single image.
 
-        Returns:
-            The dataset with captions.
+        Args:
+            example (dict): A dictionary containing the image data.
 
+        Returns:
+            dict: The dictionary with the generated caption.
         """
-        self.dataset = self.process_dataset()
-
-        for idx in tqdm(range(len(self.dataset))):
-            image = self.dataset[idx]["image"].convert("RGB")
-            inputs = self.processor(images=image, return_tensors="pt").to(device)
-            prompt_inputs = self.processor(text=[self.prompt], return_tensors="pt").to(device)
-            outputs = self.model.generate(**inputs, **prompt_inputs)
-            blip_caption = self.processor.decode(outputs[0], skip_special_tokens=True)
-            self.dataset[idx]["text"] = blip_caption
-
+        image = example["image"].convert("RGB")
+        inputs = self.processor(images=image, return_tensors="pt").to(device)
+        prompt_inputs = self.processor(text=[self.prompt], return_tensors="pt").to(device)
+        outputs = self.model.generate(**inputs, **prompt_inputs)
+        blip_caption = self.processor.decode(outputs[0], skip_special_tokens=True)
+        example["text"] = blip_caption
+        return example
 
-
+    def generate_captions(self):
+        """
+        Generates captions for all images in the dataset.
+
+        Returns:
+            Dataset: The dataset with generated captions.
+        """
+        self.dataset = self.process_dataset()
+        self.dataset = self.dataset.map(self.generate_caption, batched=False)
         return self.dataset
 
 # Initialize ImageCaptioner
@@ -76,10 +92,11 @@ ic = ImageCaptioner(
     dataset=PRODUCTS_10k_DATASET,
     processor=CAPTIONING_MODEL_NAME,
     model=CAPTIONING_MODEL_NAME,
-    prompt='Photography of '  # Adding the conditioning prompt
+    prompt='Commercial photography of'
 )
 
 # Generate captions for the dataset
 products10k_dataset = ic.generate_captions()
-new_dataset = Dataset.from_pandas(products10k_dataset.to_pandas())  # Convert to a `datasets` Dataset if necessary
-new_dataset.push_to_hub("VikramSingh178/Products-10k-BLIP-captions")
+
+# Save the dataset to the hub
+products10k_dataset.push_to_hub("VikramSingh178/Products-10k-BLIP-captions")
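The script imports `PRODUCTS_10k_DATASET` and `CAPTIONING_MODEL_NAME` from `config.py`, which is not part of this diff. Going by the commit message, the captioning model is Salesforce/blip-image-captioning-large; the dataset identifier below is only a placeholder and may differ from the real config:

```python
# config.py -- hypothetical sketch; only the model name is confirmed by the commit message.
CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-large"
PRODUCTS_10k_DATASET = "products-10k"  # placeholder identifier, not confirmed by this commit
```

Note that the final `push_to_hub` call requires an authenticated Hugging Face session (for example via `huggingface-cli login` or the `HF_TOKEN` environment variable), and that `dataset.map(..., batched=False)` runs the model one image at a time; batched generation would be a possible speed-up but is not what this commit does.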