VikramSingh178 committed on
Commit
40de55c
1 Parent(s): b427e12

chore: Add augment_image function to utils.py

Browse files
api/routers/painting.py CHANGED
@@ -12,7 +12,7 @@ from hydra import compose, initialize
12
  from async_batcher.batcher import AsyncBatcher
13
  import json
14
  from functools import lru_cache
15
-
16
  pl.seed_everything(42)
17
  router = APIRouter()
18
 
@@ -31,6 +31,7 @@ def load_pipeline_wrapper():
31
  """
32
  pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
33
  return pipeline
 
34
  inpainting_pipeline = load_pipeline_wrapper()
35
 
36
  class InpaintingRequest(BaseModel):
@@ -44,8 +45,7 @@ class InpaintingRequest(BaseModel):
44
  guidance_scale: float = Field(..., description="Guidance scale for inference")
45
  mode: str = Field(..., description="Mode for output ('b64_json' or 's3_json')")
46
  num_images: int = Field(..., description="Number of images to generate")
47
- use_augmentation: bool = Field(True, description="Whether to use image augmentation")
48
-
49
  class InpaintingBatchRequestModel(BaseModel):
50
  """
51
  Model representing a batch request for inpainting inference.
@@ -68,35 +68,14 @@ async def save_image(image: UploadFile) -> str:
68
  f.write(await image.read())
69
  return file_path
70
 
71
- def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
72
- """
73
- Augment an image by extending its dimensions and generating masks.
74
-
75
- Args:
76
- image_path (str): Path to the image file.
77
- target_width (int): Target width for augmentation.
78
- target_height (int): Target height for augmentation.
79
- roi_scale (float): Scale factor for region of interest.
80
- segmentation_model_name (str): Name of the segmentation model.
81
- detection_model_name (str): Name of the detection model.
82
-
83
- Returns:
84
- Tuple[Image.Image, Image.Image]: Augmented image and inverted mask.
85
- """
86
- image = Image.open(image_path)
87
- image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
88
- image = image_augmentation.extend_image(image)
89
- mask = image_augmentation.generate_mask_from_bbox(image, segmentation_model_name, detection_model_name)
90
- inverted_mask = image_augmentation.invert_mask(mask)
91
- return image, inverted_mask
92
-
93
- def run_inference(cfg, image_path: str, request: InpaintingRequest):
94
  """
95
  Run inference using an inpainting pipeline on an image.
96
 
97
  Args:
98
  cfg (dict): Configuration dictionary.
99
  image_path (str): Path to the image file.
 
100
  request (InpaintingRequest): Pydantic model containing inference parameters.
101
 
102
  Returns:
@@ -105,17 +84,8 @@ def run_inference(cfg, image_path: str, request: InpaintingRequest):
105
  Raises:
106
  ValueError: If an invalid mode is provided.
107
  """
108
- if request.use_augmentation:
109
- image, mask_image = augment_image(image_path,
110
- cfg['target_width'],
111
- cfg['target_height'],
112
- cfg['roi_scale'],
113
- cfg['segmentation_model'],
114
- cfg['detection_model'])
115
- else:
116
- image = Image.open(image_path)
117
- mask_image = None
118
-
119
  painting_pipeline = AutoPaintingPipeline(
120
  pipeline=inpainting_pipeline,
121
  image=image,
@@ -137,26 +107,33 @@ def run_inference(cfg, image_path: str, request: InpaintingRequest):
137
  raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
138
 
139
  class InpaintingBatcher(AsyncBatcher):
140
- async def process_batch(self, batch: Tuple[List[str], List[InpaintingRequest]]) -> List[Dict[str, Any]]:
 
 
 
141
  """
142
  Process a batch of images and requests for inpainting inference.
143
 
144
  Args:
145
- batch (Tuple[List[str], List[InpaintingRequest]]): Tuple of image paths and corresponding requests.
146
 
147
  Returns:
148
  List[Dict[str, Any]]: List of resulting images in the specified mode ('b64_json' or 's3_json').
149
  """
150
- image_paths, requests = batch
151
  results = []
152
- for image_path, request in zip(image_paths, requests):
153
- result = run_inference(cfg, image_path, request)
154
- results.append(result)
 
 
 
155
  return results
156
 
157
  @router.post("/inpainting")
158
  async def inpainting_inference(
159
  image: UploadFile = File(...),
 
160
  request_data: str = Form(...),
161
  ):
162
  """
@@ -164,6 +141,7 @@ async def inpainting_inference(
164
 
165
  Args:
166
  image (UploadFile): Uploaded image file.
 
167
  request_data (str): JSON string of the request parameters.
168
 
169
  Returns:
@@ -174,9 +152,10 @@ async def inpainting_inference(
174
  """
175
  try:
176
  image_path = await save_image(image)
 
177
  request_dict = json.loads(request_data)
178
  request = InpaintingRequest(**request_dict)
179
- result = run_inference(cfg, image_path, request)
180
  return result
181
  except Exception as e:
182
  raise HTTPException(status_code=500, detail=str(e))
@@ -184,6 +163,7 @@ async def inpainting_inference(
184
  @router.post("/inpainting/batch")
185
  async def inpainting_batch_inference(
186
  images: List[UploadFile] = File(...),
 
187
  request_data: str = Form(...),
188
  ):
189
  """
@@ -191,6 +171,7 @@ async def inpainting_batch_inference(
191
 
192
  Args:
193
  images (List[UploadFile]): List of uploaded image files.
 
194
  request_data (str): JSON string of the request parameters.
195
 
196
  Returns:
@@ -204,13 +185,14 @@ async def inpainting_batch_inference(
204
  batch_request = InpaintingBatchRequestModel(**request_dict)
205
  requests = batch_request.requests
206
 
207
- if len(images) != len(requests):
208
- raise HTTPException(status_code=400, detail="The number of images and requests must match.")
209
 
210
  batcher = InpaintingBatcher(max_batch_size=64)
211
- image_paths = [await save_image(image) for image in images]
212
- results = batcher.process_batch((image_paths, requests))
 
213
 
214
  return results
215
  except Exception as e:
216
- raise HTTPException(status_code=500, detail=str(e))
 
12
  from async_batcher.batcher import AsyncBatcher
13
  import json
14
  from functools import lru_cache
15
+ import asyncio
16
  pl.seed_everything(42)
17
  router = APIRouter()
18
 
 
31
  """
32
  pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
33
  return pipeline
34
+
35
  inpainting_pipeline = load_pipeline_wrapper()
36
 
37
  class InpaintingRequest(BaseModel):
 
45
  guidance_scale: float = Field(..., description="Guidance scale for inference")
46
  mode: str = Field(..., description="Mode for output ('b64_json' or 's3_json')")
47
  num_images: int = Field(..., description="Number of images to generate")
48
+
 
49
  class InpaintingBatchRequestModel(BaseModel):
50
  """
51
  Model representing a batch request for inpainting inference.
 
68
  f.write(await image.read())
69
  return file_path
70
 
71
+ def run_inference(cfg, image_path: str, mask_image_path: str, request: InpaintingRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """
73
  Run inference using an inpainting pipeline on an image.
74
 
75
  Args:
76
  cfg (dict): Configuration dictionary.
77
  image_path (str): Path to the image file.
78
+ mask_image_path (str): Path to the mask image file.
79
  request (InpaintingRequest): Pydantic model containing inference parameters.
80
 
81
  Returns:
 
84
  Raises:
85
  ValueError: If an invalid mode is provided.
86
  """
87
+ image = Image.open(image_path)
88
+ mask_image = Image.open(mask_image_path)
 
 
 
 
 
 
 
 
 
89
  painting_pipeline = AutoPaintingPipeline(
90
  pipeline=inpainting_pipeline,
91
  image=image,
 
107
  raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
108
 
109
  class InpaintingBatcher(AsyncBatcher):
110
+ def __init__(self, max_batch_size: int):
111
+ super().__init__(max_batch_size)
112
+
113
+ async def process_batch(self, batch: Tuple[List[str], List[str], List[InpaintingRequest]]) -> List[Dict[str, Any]]:
114
  """
115
  Process a batch of images and requests for inpainting inference.
116
 
117
  Args:
118
+ batch (Tuple[List[str], List[str], List[InpaintingRequest]]): Tuple of image paths, mask image paths, and corresponding requests.
119
 
120
  Returns:
121
  List[Dict[str, Any]]: List of resulting images in the specified mode ('b64_json' or 's3_json').
122
  """
123
+ image_paths, mask_image_paths, requests = batch
124
  results = []
125
+ for image_path, mask_image_path, request in zip(image_paths, mask_image_paths, requests):
126
+ try:
127
+ result = run_inference(cfg, image_path, mask_image_path, request)
128
+ results.append(result)
129
+ except Exception as e:
130
+ results.append({"error": str(e)})
131
  return results
132
 
133
  @router.post("/inpainting")
134
  async def inpainting_inference(
135
  image: UploadFile = File(...),
136
+ mask_image: UploadFile = File(...),
137
  request_data: str = Form(...),
138
  ):
139
  """
 
141
 
142
  Args:
143
  image (UploadFile): Uploaded image file.
144
+ mask_image (UploadFile): Uploaded mask image file.
145
  request_data (str): JSON string of the request parameters.
146
 
147
  Returns:
 
152
  """
153
  try:
154
  image_path = await save_image(image)
155
+ mask_image_path = await save_image(mask_image)
156
  request_dict = json.loads(request_data)
157
  request = InpaintingRequest(**request_dict)
158
+ result = run_inference(cfg, image_path, mask_image_path, request)
159
  return result
160
  except Exception as e:
161
  raise HTTPException(status_code=500, detail=str(e))
 
163
  @router.post("/inpainting/batch")
164
  async def inpainting_batch_inference(
165
  images: List[UploadFile] = File(...),
166
+ mask_images: List[UploadFile] = File(...),
167
  request_data: str = Form(...),
168
  ):
169
  """
 
171
 
172
  Args:
173
  images (List[UploadFile]): List of uploaded image files.
174
+ mask_images (List[UploadFile]): List of uploaded mask image files.
175
  request_data (str): JSON string of the request parameters.
176
 
177
  Returns:
 
185
  batch_request = InpaintingBatchRequestModel(**request_dict)
186
  requests = batch_request.requests
187
 
188
+ if len(images) != len(requests) or len(images) != len(mask_images):
189
+ raise HTTPException(status_code=400, detail="The number of images, mask images, and requests must match.")
190
 
191
  batcher = InpaintingBatcher(max_batch_size=64)
192
+ image_paths = await asyncio.gather(*[save_image(image) for image in images])
193
+ mask_image_paths = await asyncio.gather(*[save_image(mask_image) for mask_image in mask_images])
194
+ results = await batcher.process_batch((image_paths, mask_image_paths, requests))
195
 
196
  return results
197
  except Exception as e:
198
+ raise HTTPException(status_code=500, detail=str(e))
ui/__pycache__/ui.cpython-311.pyc ADDED
Binary file (9.18 kB). View file
 
ui/ui.py CHANGED
@@ -1,16 +1,15 @@
1
  import gradio as gr
2
  import requests
3
- from pydantic import BaseModel
4
  from diffusers.utils import load_image
5
  from io import BytesIO
6
-
7
-
8
 
9
  sdxl_inference_endpoint = 'https://vikramsingh178-picpilot-server.hf.space/api/v1/product-diffusion/sdxl_v0_lora_inference'
10
  sdxl_batch_inference_endpoint = 'https://vikramsingh178-picpilot-server.hf.space/api/v1/product-diffusion/sdxl_v0_lora_inference/batch'
11
  kandinsky_inpainting_inference = 'https://vikramsingh178-picpilot-server.hf.space/api/v1/product-diffusion/inpainting'
12
 
13
- # Define the InpaintingRequest model
14
  class InputRequest(BaseModel):
15
  prompt: str
16
  num_inference_steps: int
@@ -20,15 +19,15 @@ class InputRequest(BaseModel):
20
  mode: str
21
 
22
  class InpaintingRequest(BaseModel):
23
- prompt: str
24
- negative_prompt: str
25
- num_inference_steps: int
26
- strength: float
27
- guidance_scale: float
28
- mode: str
 
29
 
30
  async def generate_sdxl_lora_image(prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, mode):
31
- # Prepare the payload for SDXL LORA API
32
  payload = InputRequest(
33
  prompt=prompt,
34
  negative_prompt=negative_prompt,
@@ -44,24 +43,46 @@ async def generate_sdxl_lora_image(prompt, negative_prompt, num_inference_steps,
44
  image = load_image(url)
45
  return image
46
 
 
 
 
 
 
 
 
 
47
 
48
-
49
- def generate_outpainting(prompt, negative_prompt, num_inference_steps, strength, guidance_scale, mode, image):
50
- # Convert the image to bytes
 
51
  img_byte_arr = BytesIO()
52
- image.save(img_byte_arr, format='PNG')
53
  img_byte_arr = img_byte_arr.getvalue()
 
 
 
 
54
 
55
  # Prepare the payload for multipart/form-data
56
  files = {
57
  'image': ('image.png', img_byte_arr, 'image/png'),
58
- 'prompt': (None, prompt),
59
- 'negative_prompt': (None, negative_prompt),
60
- 'num_inference_steps': (None, str(num_inference_steps)),
61
- 'strength': (None, str(strength)),
62
- 'guidance_scale': (None, str(guidance_scale)),
63
- 'mode': (None, mode)
64
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  response = requests.post(kandinsky_inpainting_inference, files=files)
67
  response.raise_for_status()
@@ -70,44 +91,39 @@ def generate_outpainting(prompt, negative_prompt, num_inference_steps, strength,
70
  image = load_image(url)
71
  return image
72
 
73
-
74
-
75
  with gr.Blocks(theme='VikramSingh178/Webui-Theme') as demo:
76
  with gr.Tab("SdxL-Lora"):
77
- with gr.Row():
78
  with gr.Column():
79
- with gr.Group():
80
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
81
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here")
82
- num_inference_steps = gr.Slider(minimum=1, maximum=1000, step=1, value=20, label="Inference Steps")
83
- guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.1, value=7.5, label="Guidance Scale")
84
- num_images = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Images")
85
- mode = gr.Dropdown(choices=["s3_json", "b64_json"], value="s3_json", label="Mode")
86
- generate_button = gr.Button("Generate Image",variant='primary')
87
 
88
  with gr.Column(scale=1):
89
-
90
- image_preview = gr.Image(label="Generated Image",show_download_button=True,show_share_button=True,container=True)
91
  generate_button.click(generate_sdxl_lora_image, inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, mode], outputs=[image_preview])
92
 
93
- with gr.Tab("Generate AI Background"):
94
  with gr.Row():
95
  with gr.Column():
96
- with gr.Group():
97
- image_input = gr.Image(type="pil", label="Upload Image")
98
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
99
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here")
100
- num_inference_steps = gr.Slider(minimum=1, maximum=500, step=1, value=20, label="Inference Steps")
101
- guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.1, value=7.5, label="Guidance Scale")
102
- strength = gr.Slider(minimum=0.1, maximum=1, step=0.1, value=1, label="Strength")
103
- mode = gr.Dropdown(choices=["s3_json", "b64_json"], value="s3_json", label="Mode")
104
- generate_button = gr.Button("Generate Background", variant='primary')
 
105
 
106
  with gr.Column(scale=1):
107
-
108
- image_preview = gr.Image(label="Image", show_download_button=True, show_share_button=True, container=True)
109
- generate_button.click(generate_outpainting, inputs=[prompt, negative_prompt, num_inference_steps, strength, guidance_scale, mode, image_input], outputs=[image_preview])
110
-
111
- demo.launch()
112
-
113
 
 
 
1
  import gradio as gr
2
  import requests
3
+ from pydantic import BaseModel, Field
4
  from diffusers.utils import load_image
5
  from io import BytesIO
6
+ import json
7
+ import numpy as np
8
 
9
  sdxl_inference_endpoint = 'https://vikramsingh178-picpilot-server.hf.space/api/v1/product-diffusion/sdxl_v0_lora_inference'
10
  sdxl_batch_inference_endpoint = 'https://vikramsingh178-picpilot-server.hf.space/api/v1/product-diffusion/sdxl_v0_lora_inference/batch'
11
  kandinsky_inpainting_inference = 'https://vikramsingh178-picpilot-server.hf.space/api/v1/product-diffusion/inpainting'
12
 
 
13
  class InputRequest(BaseModel):
14
  prompt: str
15
  num_inference_steps: int
 
19
  mode: str
20
 
21
  class InpaintingRequest(BaseModel):
22
+ prompt: str = Field(..., description="Prompt text for inference")
23
+ negative_prompt: str = Field(..., description="Negative prompt text for inference")
24
+ num_inference_steps: int = Field(..., description="Number of inference steps")
25
+ strength: float = Field(..., description="Strength of the inference")
26
+ guidance_scale: float = Field(..., description="Guidance scale for inference")
27
+ mode: str = Field(..., description="Mode for output ('b64_json' or 's3_json')")
28
+ num_images: int = Field(..., description="Number of images to generate")
29
 
30
  async def generate_sdxl_lora_image(prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, mode):
 
31
  payload = InputRequest(
32
  prompt=prompt,
33
  negative_prompt=negative_prompt,
 
43
  image = load_image(url)
44
  return image
45
 
46
+ def process_masked_image(img):
47
+ base_image = img["image"]
48
+ mask = img["mask"]
49
+
50
+ # Convert mask to binary (0 or 255)
51
+ mask = np.where(mask > 0, 255, 0).astype(np.uint8)
52
+
53
+ return base_image, mask
54
 
55
+ def generate_outpainting(prompt, negative_prompt, num_inference_steps, strength, guidance_scale, mode, num_images, masked_image):
56
+ base_image, mask = process_masked_image(masked_image)
57
+
58
+ # Convert the images to bytes
59
  img_byte_arr = BytesIO()
60
+ base_image.save(img_byte_arr, format='PNG')
61
  img_byte_arr = img_byte_arr.getvalue()
62
+
63
+ mask_byte_arr = BytesIO()
64
+ mask_image = gr.processing_utils.encode_pil_to_base64(mask)
65
+ mask_byte_arr = mask_image.getvalue()
66
 
67
  # Prepare the payload for multipart/form-data
68
  files = {
69
  'image': ('image.png', img_byte_arr, 'image/png'),
70
+ 'mask_image': ('mask.png', mask_byte_arr, 'image/png'),
 
 
 
 
 
71
  }
72
+
73
+ # Prepare the request data
74
+ request_data = InpaintingRequest(
75
+ prompt=prompt,
76
+ negative_prompt=negative_prompt,
77
+ num_inference_steps=num_inference_steps,
78
+ strength=strength,
79
+ guidance_scale=guidance_scale,
80
+ mode=mode,
81
+ num_images=num_images
82
+ ).dict()
83
+
84
+ # Add the JSON-encoded request data to the files dictionary
85
+ files['request_data'] = ('request_data.json', json.dumps(request_data), 'application/json')
86
 
87
  response = requests.post(kandinsky_inpainting_inference, files=files)
88
  response.raise_for_status()
 
91
  image = load_image(url)
92
  return image
93
 
 
 
94
  with gr.Blocks(theme='VikramSingh178/Webui-Theme') as demo:
95
  with gr.Tab("SdxL-Lora"):
96
+ with gr.Row():
97
  with gr.Column():
98
+ with gr.Group():
99
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
100
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here")
101
+ num_inference_steps = gr.Slider(minimum=1, maximum=1000, step=1, value=20, label="Inference Steps")
102
+ guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.1, value=7.5, label="Guidance Scale")
103
+ num_images = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Images")
104
+ mode = gr.Dropdown(choices=["s3_json", "b64_json"], value="s3_json", label="Mode")
105
+ generate_button = gr.Button("Generate Image", variant='primary')
106
 
107
  with gr.Column(scale=1):
108
+ image_preview = gr.Image(label="Generated Image", show_download_button=True, show_share_button=True, container=True)
 
109
  generate_button.click(generate_sdxl_lora_image, inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, mode], outputs=[image_preview])
110
 
111
+ with gr.Tab("Inpainting"):
112
  with gr.Row():
113
  with gr.Column():
114
+ with gr.Group():
115
+ masked_image = gr.ImageMask(label="Upload Image and Draw Mask")
116
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
117
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here")
118
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=20, label="Inference Steps")
119
+ strength = gr.Slider(minimum=0.1, maximum=1, step=0.1, value=0.8, label="Strength")
120
+ guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.1, value=7.5, label="Guidance Scale")
121
+ num_images = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Images")
122
+ mode = gr.Dropdown(choices=["s3_json", "b64_json"], value="s3_json", label="Mode")
123
+ generate_button = gr.Button("Generate Inpainting", variant='primary')
124
 
125
  with gr.Column(scale=1):
126
+ image_preview = gr.Image(label="Inpainted Image", show_download_button=True, show_share_button=True, container=True)
127
+ generate_button.click(generate_outpainting, inputs=[prompt, negative_prompt, num_inference_steps, strength, guidance_scale, mode, num_images, masked_image], outputs=[image_preview])
 
 
 
 
128
 
129
+ demo.launch()
ui/utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scripts.api_utils import ImageAugmentation
2
+ from PIL import Image
3
+
4
+
5
+
6
+
7
+ def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
8
+ """
9
+ Augment an image by extending its dimensions and generating masks.
10
+
11
+ Args:
12
+ image_path (str): Path to the image file.
13
+ target_width (int): Target width for augmentation.
14
+ target_height (int): Target height for augmentation.
15
+ roi_scale (float): Scale factor for region of interest.
16
+ segmentation_model_name (str): Name of the segmentation model.
17
+ detection_model_name (str): Name of the detection model.
18
+
19
+ Returns:
20
+ Tuple[Image.Image, Image.Image]: Augmented image and inverted mask.
21
+ """
22
+ image = Image.open(image_path)
23
+ image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
24
+ image = image_augmentation.extend_image(image)
25
+ mask = image_augmentation.generate_mask_from_bbox(image, segmentation_model_name, detection_model_name)
26
+ inverted_mask = image_augmentation.invert_mask(mask)
27
+ return image, inverted_mask