VikramSingh178 commited on
Commit
6c850b1
1 Parent(s): 47b9b86

chore: Update import statement for InpaintingRequest in painting.py and refactor code to use shared BaseModel for Painting and InpaintingRequest classes

Browse files

Former-commit-id: a23f3f613d6636adcadd1b01d7f1d61993d91251 [formerly e5c83686f1238c8a9cb66b21590a5a1ee6597665]
Former-commit-id: 2252984afe4046eea2172d7ddd6d0096cf70b07b

api/__pycache__/endpoints.cpython-310.pyc CHANGED
Binary files a/api/__pycache__/endpoints.cpython-310.pyc and b/api/__pycache__/endpoints.cpython-310.pyc differ
 
api/endpoints.py CHANGED
@@ -51,5 +51,5 @@ async def root():
51
  def check_health():
52
  return {"status": "ok"}
53
 
54
- uvicorn.run(app, host="0.0.0.0", port=8000)
55
 
 
51
  def check_health():
52
  return {"status": "ok"}
53
 
54
+
55
 
api/models/__pycache__/painting.cpython-310.pyc CHANGED
Binary files a/api/models/__pycache__/painting.cpython-310.pyc and b/api/models/__pycache__/painting.cpython-310.pyc differ
 
api/models/painting.py CHANGED
@@ -1,4 +1,5 @@
1
  from pydantic import BaseModel
 
2
 
3
 
4
  class InpaintingRequest(BaseModel):
@@ -7,3 +8,4 @@ class InpaintingRequest(BaseModel):
7
  num_inference_steps: int
8
  strength: float
9
  guidance_scale: float
 
 
1
  from pydantic import BaseModel
2
+ from fastapi import Form
3
 
4
 
5
  class InpaintingRequest(BaseModel):
 
8
  num_inference_steps: int
9
  strength: float
10
  guidance_scale: float
11
+
api/routers/__pycache__/painting.cpython-310.pyc CHANGED
Binary files a/api/routers/__pycache__/painting.cpython-310.pyc and b/api/routers/__pycache__/painting.cpython-310.pyc differ
 
api/routers/painting.py CHANGED
@@ -1,69 +1,34 @@
 
 
1
  import sys
2
  sys.path.append("../scripts")
3
- from fastapi import APIRouter, File, UploadFile, HTTPException
4
- from pydantic import BaseModel
5
- from PIL import Image
6
- from io import BytesIO
7
- from models.painting import InpaintingRequest
8
  import uuid
9
- from inpainting_pipeline import AutoPaintingPipeline
10
- from utils import pil_to_s3_json, ImageAugmentation
11
- from hydra import compose, initialize
12
  import lightning.pytorch as pl
13
- pl.seed_everything(42)
14
-
15
- router = APIRouter()
16
-
17
-
 
 
18
 
19
- #class InpaintingRequest(BaseModel):
20
- # prompt: str
21
- # negative_prompt: str
22
- # num_inference_steps: int
23
- # strength: float
24
- # guidance_scale: float
25
 
26
- def augment_image(image, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
27
- """
28
- Augments an image with a given prompt, model, and other parameters.
29
 
30
- Parameters:
31
- - image (str): The path to the image file.
32
- - target_width (int): The desired width of the augmented image.
33
- - target_height (int): The desired height of the augmented image.
34
- - roi_scale (float): The scale factor for the region of interest.
35
 
36
- Returns:
37
- - augmented_image (PIL.Image.Image): The augmented image.
38
- - inverted_mask (PIL.Image.Image): The inverted mask generated from the augmented image.
39
- """
40
- image = Image.open(image)
41
  image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
42
  image = image_augmentation.extend_image(image)
43
  mask = image_augmentation.generate_mask_from_bbox(image, segmentation_model_name, detection_model_name)
44
  inverted_mask = image_augmentation.invert_mask(mask)
45
  return image, inverted_mask
46
 
47
- def run_inference(cfg: dict, image_path: str, prompt: str, negative_prompt: str, num_inference_steps: int, strength: float, guidance_scale: float):
48
- """
49
- Run inference using the provided configuration and input image.
50
-
51
- Args:
52
- cfg (dict): Configuration dictionary containing model parameters.
53
- image_path (str): Path to the input image file.
54
- prompt (str): Prompt for the inference process.
55
- negative_prompt (str): Negative prompt for the inference process.
56
- num_inference_steps (int): Number of inference steps to perform.
57
- strength (float): Strength parameter for the inference.
58
- guidance_scale (float): Guidance scale for the inference.
59
-
60
- Returns:
61
- dict: A JSON object containing the image ID and the signed URL.
62
-
63
- Raises:
64
- HTTPException: If an error occurs during the inference process.
65
-
66
- """
67
  image, mask_image = augment_image(image_path,
68
  cfg['target_width'],
69
  cfg['target_height'],
@@ -71,25 +36,89 @@ def run_inference(cfg: dict, image_path: str, prompt: str, negative_prompt: str,
71
  cfg['segmentation_model'],
72
  cfg['detection_model'])
73
 
74
- pipeline = AutoPaintingPipeline(model_name=cfg['model'],
75
- image=image,
76
- mask_image=mask_image,
77
- target_height=cfg['target_height'],
78
- target_width=cfg['target_width'])
79
- output = pipeline.run_inference(prompt=prompt,
 
 
80
  negative_prompt=negative_prompt,
81
  num_inference_steps=num_inference_steps,
82
  strength=strength,
83
  guidance_scale=guidance_scale)
84
- return pil_to_s3_json(output, file_name="output.png")
85
-
86
- @router.post("/kandinskyv2.2_inpainting")
87
- async def inpainting_inference(image: UploadFile = File(...),
88
- prompt: str = "",
89
- negative_prompt: str = "",
90
- num_inference_steps: int = 50,
91
- strength: float = 0.5,
92
- guidance_scale: float = 7.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  """
94
  Run the inpainting/outpainting inference pipeline.
95
 
@@ -100,6 +129,8 @@ async def inpainting_inference(image: UploadFile = File(...),
100
  - num_inference_steps: int - The number of inference steps to perform during the inpainting/outpainting process.
101
  - strength: float - The strength parameter for controlling the inpainting/outpainting process.
102
  - guidance_scale: float - The guidance scale parameter for controlling the inpainting/outpainting process.
 
 
103
 
104
  Returns:
105
  - result: The result of the inpainting/outpainting process.
@@ -113,14 +144,47 @@ async def inpainting_inference(image: UploadFile = File(...),
113
  with open(image_path, "wb") as f:
114
  f.write(image_bytes)
115
 
 
 
 
116
 
117
- with initialize(version_base=None,config_path="../../configs"):
118
- cfg = compose(config_name="inpainting")
 
 
 
 
 
 
 
 
 
119
 
120
- result = run_inference(cfg, image_path, prompt, negative_prompt, num_inference_steps, strength, guidance_scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- return result
123
  except Exception as e:
124
  raise HTTPException(status_code=500, detail=str(e))
 
125
 
126
 
 
1
+ from fastapi import APIRouter, File, UploadFile, HTTPException, Form
2
+ from PIL import Image
3
  import sys
4
  sys.path.append("../scripts")
 
 
 
 
 
5
  import uuid
 
 
 
6
  import lightning.pytorch as pl
7
+ from typing import List
8
+ from utils import pil_to_s3_json, pil_to_b64_json, ImageAugmentation, accelerator
9
+ from inpainting_pipeline import AutoPaintingPipeline, load_pipeline
10
+ from hydra import compose, initialize
11
+ from pydantic import BaseModel
12
+ from async_batcher.batcher import AsyncBatcher
13
+ from typing import Dict
14
 
 
 
 
 
 
 
15
 
16
+ router = APIRouter()
17
+ pl.seed_everything(42)
 
18
 
19
+ with initialize(version_base=None, config_path="../../configs"):
20
+ cfg = compose(config_name="inpainting")
21
+ inpainting_pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
 
 
22
 
23
+ def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
24
+ image = Image.open(image_path)
 
 
 
25
  image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
26
  image = image_augmentation.extend_image(image)
27
  mask = image_augmentation.generate_mask_from_bbox(image, segmentation_model_name, detection_model_name)
28
  inverted_mask = image_augmentation.invert_mask(mask)
29
  return image, inverted_mask
30
 
31
+ def run_inference(cfg: dict, image_path: str, prompt: str, negative_prompt: str, num_inference_steps: int, strength: float, guidance_scale: float, mode: str, num_images: int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  image, mask_image = augment_image(image_path,
33
  cfg['target_width'],
34
  cfg['target_height'],
 
36
  cfg['segmentation_model'],
37
  cfg['detection_model'])
38
 
39
+ painting_pipeline = AutoPaintingPipeline(
40
+ pipeline=inpainting_pipeline,
41
+ image=image,
42
+ mask_image=mask_image,
43
+ target_height=cfg['target_height'],
44
+ target_width=cfg['target_width']
45
+ )
46
+ output = painting_pipeline.run_inference(prompt=prompt,
47
  negative_prompt=negative_prompt,
48
  num_inference_steps=num_inference_steps,
49
  strength=strength,
50
  guidance_scale=guidance_scale)
51
+ if mode == "s3_json":
52
+ return pil_to_s3_json(output, file_name="output.png")
53
+ elif mode == "b64_json":
54
+ return pil_to_b64_json(output)
55
+ else:
56
+ raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
57
+
58
+ class InpaintingRequest(BaseModel):
59
+ prompt: str
60
+ negative_prompt: str
61
+ num_inference_steps: int
62
+ strength: float
63
+ guidance_scale: float
64
+ num_images: int = 1
65
+
66
+ class InpaintingBatcher(AsyncBatcher[List[Dict], dict]):
67
+ def __init__(self, pipeline, cfg):
68
+ self.pipeline = pipeline
69
+ self.cfg = cfg
70
+
71
+ def process_batch(self, batch: List[Dict], image_paths: List[str]) -> List[dict]:
72
+ results = []
73
+ for data, image_path in zip(batch, image_paths):
74
+ try:
75
+ image, mask_image = augment_image(
76
+ image_path,
77
+ self.cfg['target_width'],
78
+ self.cfg['target_height'],
79
+ self.cfg['roi_scale'],
80
+ self.cfg['segmentation_model'],
81
+ self.cfg['detection_model']
82
+ )
83
+
84
+ pipeline = AutoPaintingPipeline(
85
+ image=image,
86
+ mask_image=mask_image,
87
+ target_height=self.cfg['target_height'],
88
+ target_width=self.cfg['target_width']
89
+ )
90
+ output = pipeline.run_inference(
91
+ prompt=data['prompt'],
92
+ negative_prompt=data['negative_prompt'],
93
+ num_inference_steps=data['num_inference_steps'],
94
+ strength=data['strength'],
95
+ guidance_scale=data['guidance_scale']
96
+ )
97
+
98
+ if data['mode'] == "s3_json":
99
+ result = pil_to_s3_json(output, 'inpainting_image')
100
+ elif data['mode'] == "b64_json":
101
+ result = pil_to_b64_json(output)
102
+ else:
103
+ raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
104
+
105
+ results.append(result)
106
+ except Exception as e:
107
+ print(f"Error in process_batch: {e}")
108
+ raise HTTPException(status_code=500, detail="Batch inference failed")
109
+ return results
110
+
111
+ @router.post("/inpainting")
112
+ async def inpainting_inference(
113
+ image: UploadFile = File(...),
114
+ prompt: str = Form(...),
115
+ negative_prompt: str = Form(...),
116
+ num_inference_steps: int = Form(...),
117
+ strength: float = Form(...),
118
+ guidance_scale: float = Form(...),
119
+ mode: str = Form(...),
120
+ num_images: int = Form(1)
121
+ ):
122
  """
123
  Run the inpainting/outpainting inference pipeline.
124
 
 
129
  - num_inference_steps: int - The number of inference steps to perform during the inpainting/outpainting process.
130
  - strength: float - The strength parameter for controlling the inpainting/outpainting process.
131
  - guidance_scale: float - The guidance scale parameter for controlling the inpainting/outpainting process.
132
+ - mode: str - The output mode, either "s3_json" or "b64_json".
133
+ - num_images: int - The number of images to generate.
134
 
135
  Returns:
136
  - result: The result of the inpainting/outpainting process.
 
144
  with open(image_path, "wb") as f:
145
  f.write(image_bytes)
146
 
147
+ result = run_inference(
148
+ cfg, image_path, prompt, negative_prompt, num_inference_steps, strength, guidance_scale, mode, num_images
149
+ )
150
 
151
+ return result
152
+ except Exception as e:
153
+ raise HTTPException(status_code=500, detail=str(e))
154
+
155
+ @router.post("/inpainting_batch")
156
+ async def inpainting_batch_inference(
157
+ batch: List[dict],
158
+ images: List[UploadFile] = File(...)
159
+ ):
160
+ """
161
+ Run batch inpainting/outpainting inference pipeline.
162
 
163
+ Parameters:
164
+ - batch: List[dict] - The batch of requests containing parameters for the inpainting/outpainting process.
165
+ - images: List[UploadFile] - The list of image files to be used for inpainting/outpainting.
166
+
167
+ Returns:
168
+ - results: The results of the inpainting/outpainting process for each request.
169
+
170
+ Raises:
171
+ - HTTPException: If an error occurs during the inpainting/outpainting process.
172
+ """
173
+ try:
174
+ image_paths = []
175
+ for image in images:
176
+ image_bytes = await image.read()
177
+ image_path = f"/tmp/{uuid.uuid4()}.png"
178
+ with open(image_path, "wb") as f:
179
+ f.write(image_bytes)
180
+ image_paths.append(image_path)
181
+
182
+ batcher = InpaintingBatcher(pipeline, cfg)
183
+ results = batcher.process_batch(batch, image_paths)
184
 
185
+ return results
186
  except Exception as e:
187
  raise HTTPException(status_code=500, detail=str(e))
188
+
189
 
190
 
configs/inpainting.yaml CHANGED
@@ -1,5 +1,5 @@
1
- segmentation_model : 'facebook/sam-vit-huge'
2
- detection_model : 'yolov8l'
3
  model : 'kandinsky-community/kandinsky-2-2-decoder-inpaint'
4
  target_width : 2560
5
  target_height : 1472
 
1
+ segmentation_model : 'facebook/sam-vit-base'
2
+ detection_model : 'yolov8s'
3
  model : 'kandinsky-community/kandinsky-2-2-decoder-inpaint'
4
  target_width : 2560
5
  target_height : 1472
outputs/mask.jpg CHANGED
outputs/output.jpg CHANGED
scripts/__pycache__/config.cpython-310.pyc CHANGED
Binary files a/scripts/__pycache__/config.cpython-310.pyc and b/scripts/__pycache__/config.cpython-310.pyc differ
 
scripts/__pycache__/inpainting_pipeline.cpython-310.pyc CHANGED
Binary files a/scripts/__pycache__/inpainting_pipeline.cpython-310.pyc and b/scripts/__pycache__/inpainting_pipeline.cpython-310.pyc differ
 
scripts/config.py CHANGED
@@ -9,6 +9,7 @@ CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
9
  SEGMENTATION_MODEL_NAME = "facebook/sam-vit-large"
10
  DETECTION_MODEL_NAME = "yolov8l"
11
  ENABLE_COMPILE = False
 
12
 
13
 
14
 
 
9
  SEGMENTATION_MODEL_NAME = "facebook/sam-vit-large"
10
  DETECTION_MODEL_NAME = "yolov8l"
11
  ENABLE_COMPILE = False
12
+ INPAINTING_MODEL_NAME = ''
13
 
14
 
15
 
scripts/inpainting_pipeline.py CHANGED
@@ -1,81 +1,80 @@
1
  import torch
2
- from diffusers import AutoPipelineForInpainting,DiffusionPipeline
3
  from diffusers.utils import load_image
4
- from utils import (accelerator, ImageAugmentation, clear_memory)
5
  import hydra
6
  from omegaconf import DictConfig
7
  from PIL import Image
8
  from functools import lru_cache
9
 
10
-
 
 
 
 
 
 
 
11
 
12
  class AutoPaintingPipeline:
13
- """
14
- AutoPaintingPipeline class represents a pipeline for auto painting using an inpainting model from diffusers.
15
-
16
- Args:
17
- model_name (str): The name of the pretrained inpainting model.
18
- image (Image): The input image to be processed.
19
- mask_image (Image): The mask image indicating the areas to be inpainted.
20
- """
21
-
22
- def __init__(self, model_name: str, image: Image, mask_image: Image,target_width: int, target_height: int):
23
- self.model_name = model_name
24
- self.device = accelerator()
25
- self.pipeline = AutoPipelineForInpainting.from_pretrained(self.model_name, torch_dtype=torch.float16)
26
- self.image = load_image(image)
27
- self.mask_image = load_image(mask_image)
28
  self.target_width = target_width
29
  self.target_height = target_height
30
- self.pipeline.to(self.device)
31
- self.pipeline.unet = torch.compile(self.pipeline.unet,mode='max-autotune')
32
-
33
-
34
-
35
-
36
  def run_inference(self, prompt: str, negative_prompt: str, num_inference_steps: int, strength: float, guidance_scale: float):
37
- """
38
- Runs the inference on the input image using the inpainting pipeline.
39
-
40
- Returns:
41
- Image: The output image after inpainting.
42
- """
43
-
44
- image = load_image(self.image)
45
- mask_image = load_image(self.mask_image)
46
- output = self.pipeline(prompt=prompt,negative_prompt=negative_prompt,image=image,mask_image=mask_image,num_inference_steps=num_inference_steps,strength=strength,guidance_scale=guidance_scale, height = self.target_height ,width = self.target_width).images[0]
 
 
47
  return output
48
-
49
-
50
- @hydra.main(version_base=None ,config_path="../configs", config_name="inpainting")
51
  def inference(cfg: DictConfig):
52
- """
53
- Load the configuration file for the inpainting pipeline.
54
-
55
- Args:
56
- cfg (DictConfig): The configuration file for the inpainting pipeline.
57
- """
58
  augmenter = ImageAugmentation(target_width=cfg.target_width, target_height=cfg.target_height)
59
- model_name = cfg.model
60
  image_path = "../sample_data/example3.jpg"
61
  image = Image.open(image_path)
62
  extended_image = augmenter.extend_image(image)
63
  mask_image = augmenter.generate_mask_from_bbox(extended_image, cfg.segmentation_model, cfg.detection_model)
64
  mask_image = augmenter.invert_mask(mask_image)
65
- prompt = cfg.prompt
66
- negative_prompt = cfg.negative_prompt
67
- num_inference_steps = cfg.num_inference_steps
68
- strength = cfg.strength
69
- guidance_scale = cfg.guidance_scale
70
- pipeline = AutoPaintingPipeline(model_name=model_name, image = extended_image, mask_image=mask_image, target_height=cfg.target_height, target_width=cfg.target_width)
71
- output = pipeline.run_inference(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, strength=strength, guidance_scale=guidance_scale)
72
- output.save(f'{cfg.output_path}/output.jpg')
73
- mask_image.save(f'{cfg.output_path}/mask.jpg')
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
 
 
 
 
76
  if __name__ == "__main__":
77
  inference()
78
 
79
-
80
-
81
 
 
1
  import torch
2
+ from diffusers import AutoPipelineForInpainting
3
  from diffusers.utils import load_image
4
+ from utils import accelerator, ImageAugmentation
5
  import hydra
6
  from omegaconf import DictConfig
7
  from PIL import Image
8
  from functools import lru_cache
9
 
10
+ @lru_cache(maxsize=1)
11
+ def load_pipeline(model_name: str, device, enable_compile: bool = True):
12
+ pipeline = AutoPipelineForInpainting.from_pretrained(model_name, torch_dtype=torch.float16)
13
+ if enable_compile:
14
+ pipeline.unet.to(memory_format=torch.channels_last)
15
+ pipeline.unet = torch.compile(pipeline.unet, mode='reduce-overhead',fullgraph=True)
16
+ pipeline.to(device)
17
+ return pipeline
18
 
19
  class AutoPaintingPipeline:
20
+ def __init__(self, pipeline, image: Image, mask_image: Image, target_width: int, target_height: int):
21
+ self.pipeline = pipeline
22
+ self.image = image
23
+ self.mask_image = mask_image
 
 
 
 
 
 
 
 
 
 
 
24
  self.target_width = target_width
25
  self.target_height = target_height
26
+
 
 
 
 
 
27
  def run_inference(self, prompt: str, negative_prompt: str, num_inference_steps: int, strength: float, guidance_scale: float):
28
+ output = self.pipeline(
29
+ prompt=prompt,
30
+ negative_prompt=negative_prompt,
31
+ image=self.image,
32
+ mask_image=self.mask_image,
33
+ num_inference_steps=num_inference_steps,
34
+ strength=strength,
35
+ guidance_scale=guidance_scale,
36
+ height=self.target_height,
37
+ width=self.target_width
38
+
39
+ ).images[0]
40
  return output
41
+
42
+ @hydra.main(version_base=None, config_path="../configs", config_name="inpainting")
 
43
  def inference(cfg: DictConfig):
44
+ # Load the pipeline once and cache it
45
+ pipeline = load_pipeline(cfg.model, accelerator(), True)
46
+
47
+ # Image augmentation and preparation
 
 
48
  augmenter = ImageAugmentation(target_width=cfg.target_width, target_height=cfg.target_height)
 
49
  image_path = "../sample_data/example3.jpg"
50
  image = Image.open(image_path)
51
  extended_image = augmenter.extend_image(image)
52
  mask_image = augmenter.generate_mask_from_bbox(extended_image, cfg.segmentation_model, cfg.detection_model)
53
  mask_image = augmenter.invert_mask(mask_image)
 
 
 
 
 
 
 
 
 
54
 
55
+ # Create AutoPaintingPipeline instance with cached pipeline
56
+ painting_pipeline = AutoPaintingPipeline(
57
+ pipeline=pipeline,
58
+ image=extended_image,
59
+ mask_image=mask_image,
60
+ target_height=cfg.target_height,
61
+ target_width=cfg.target_width
62
+ )
63
+
64
+ # Run inference
65
+ output = painting_pipeline.run_inference(
66
+ prompt=cfg.prompt,
67
+ negative_prompt=cfg.negative_prompt,
68
+ num_inference_steps=cfg.num_inference_steps,
69
+ strength=cfg.strength,
70
+ guidance_scale=cfg.guidance_scale
71
+ )
72
 
73
+ # Save output and mask images
74
+ output.save(f'{cfg.output_path}/output.jpg')
75
+ mask_image.save(f'{cfg.output_path}/mask.jpg')
76
+
77
  if __name__ == "__main__":
78
  inference()
79
 
 
 
80