SIGMitch commited on
Commit
2153359
·
verified ·
1 Parent(s): dc67161

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -12
app.py CHANGED
@@ -9,7 +9,7 @@ import spaces
9
  import torch
10
  from PIL import Image
11
  from diffusers import FluxInpaintPipeline
12
-
13
 
14
  MAX_SEED = np.iinfo(np.int32).max
15
  IMAGE_SIZE = 1024
@@ -33,7 +33,8 @@ def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
33
 
34
  pipe = FluxInpaintPipeline.from_pretrained(
35
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
36
-
 
37
 
38
  def resize_image_dimensions(
39
  original_resolution_wh: Tuple[int, int],
@@ -79,21 +80,34 @@ def process(
79
 
80
  if not image:
81
  gr.Info("Please upload an image.")
82
- return None, None
83
-
 
 
 
 
 
 
84
  if not mask:
85
  gr.Info("Please draw a mask on the image.")
86
- return None, None
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- width, height = resize_image_dimensions(original_resolution_wh=image.size)
89
- resized_image = image.resize((width, height), Image.LANCZOS)
90
  resized_mask = mask.resize((width, height), Image.LANCZOS)
91
 
92
- if randomize_seed_checkbox:
93
- seed_slicer = random.randint(0, MAX_SEED)
94
- generator = torch.Generator().manual_seed(seed_slicer)
95
- pipe.load_lora_weights("SIGMitch/KIT")
96
- result = pipe(
97
  prompt=input_text,
98
  image=resized_image,
99
  mask_image=resized_mask,
 
9
  import torch
10
  from PIL import Image
11
  from diffusers import FluxInpaintPipeline
12
+ from diffusers import FluxImg2ImgPipeline
13
 
14
  MAX_SEED = np.iinfo(np.int32).max
15
  IMAGE_SIZE = 1024
 
33
 
34
  pipe = FluxInpaintPipeline.from_pretrained(
35
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
36
+ pipe2 = FluxImg2ImgPipeline.from_pretrained(
37
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
38
 
39
  def resize_image_dimensions(
40
  original_resolution_wh: Tuple[int, int],
 
80
 
81
  if not image:
82
  gr.Info("Please upload an image.")
83
+ return result, None
84
+
85
+ width, height = resize_image_dimensions(original_resolution_wh=image.size)
86
+ resized_image = image.resize((width, height), Image.LANCZOS)
87
+ if randomize_seed_checkbox:
88
+ seed_slicer = random.randint(0, MAX_SEED)
89
+ generator = torch.Generator().manual_seed(seed_slicer)
90
+
91
  if not mask:
92
  gr.Info("Please draw a mask on the image.")
93
+ pipe2.load_lora_weights("SIGMitch/KIT")
94
+ result = pipe2(
95
+ prompt=input_text,
96
+ image=resized_image,
97
+ width=width,
98
+ height=height,
99
+ strength=strength_slider,
100
+ generator=generator,
101
+ joint_attention_kwargs={"scale": 1.2},
102
+ num_inference_steps=num_inference_steps_slider
103
+ ).images[0]
104
+ print('INFERENCE DONE')
105
+ return result, None
106
 
 
 
107
  resized_mask = mask.resize((width, height), Image.LANCZOS)
108
 
109
+ pipe.load_lora_weights("SIGMitch/KIT")
110
+ result = pipe(
 
 
 
111
  prompt=input_text,
112
  image=resized_image,
113
  mask_image=resized_mask,