Update app.py
app.py CHANGED
@@ -9,7 +9,7 @@ import spaces
 import torch
 from PIL import Image
 from diffusers import FluxInpaintPipeline
-
+from diffusers import FluxImg2ImgPipeline
 
 MAX_SEED = np.iinfo(np.int32).max
 IMAGE_SIZE = 1024
@@ -33,7 +33,8 @@ def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
 
 pipe = FluxInpaintPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
-
+pipe2 = FluxImg2ImgPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
 
 def resize_image_dimensions(
     original_resolution_wh: Tuple[int, int],
@@ -79,21 +80,34 @@ def process(
 
     if not image:
         gr.Info("Please upload an image.")
-        return
-
+        return result, None
+
+    width, height = resize_image_dimensions(original_resolution_wh=image.size)
+    resized_image = image.resize((width, height), Image.LANCZOS)
+    if randomize_seed_checkbox:
+        seed_slicer = random.randint(0, MAX_SEED)
+    generator = torch.Generator().manual_seed(seed_slicer)
+
     if not mask:
         gr.Info("Please draw a mask on the image.")
-        return
+        pipe2.load_lora_weights("SIGMitch/KIT")
+        result = pipe2(
+            prompt=input_text,
+            image=resized_image,
+            width=width,
+            height=height,
+            strength=strength_slider,
+            generator=generator,
+            joint_attention_kwargs={"scale": 1.2},
+            num_inference_steps=num_inference_steps_slider
+        ).images[0]
+        print('INFERENCE DONE')
+        return result, None
 
-    width, height = resize_image_dimensions(original_resolution_wh=image.size)
-    resized_image = image.resize((width, height), Image.LANCZOS)
     resized_mask = mask.resize((width, height), Image.LANCZOS)
 
-    if randomize_seed_checkbox:
-        seed_slicer = random.randint(0, MAX_SEED)
-    generator = torch.Generator().manual_seed(seed_slicer)
-    pipe.load_lora_weights("SIGMitch/KIT")
-    result = pipe(
+    pipe.load_lora_weights("SIGMitch/KIT")
+    result = pipe(
         prompt=input_text,
         image=resized_image,
         mask_image=resized_mask,
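For readers following the change: after this commit, app.py keeps two FLUX.1-schnell pipelines resident and picks one per request, running the new FluxImg2ImgPipeline path when no mask was drawn and the existing FluxInpaintPipeline path when one was. Below is a minimal sketch of that dispatch pattern; it reuses the model id and the SIGMitch/KIT LoRA repo from the diff, but the run() helper, its parameter names, and the default values are illustrative assumptions rather than code taken from app.py.

import torch
from typing import Optional
from PIL import Image
from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "black-forest-labs/FLUX.1-schnell"

# Two pipelines over the same base weights, mirroring `pipe` / `pipe2` in the diff.
inpaint_pipe = FluxInpaintPipeline.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16).to(DEVICE)
img2img_pipe = FluxImg2ImgPipeline.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16).to(DEVICE)


def run(prompt: str, image: Image.Image, mask: Optional[Image.Image],
        seed: int = 42, strength: float = 0.85, steps: int = 4) -> Image.Image:
    """Illustrative dispatch: img2img when no mask is given, inpainting otherwise."""
    generator = torch.Generator().manual_seed(seed)
    if mask is None:
        # No mask drawn: restyle the whole image with the img2img pipeline.
        img2img_pipe.load_lora_weights("SIGMitch/KIT")
        return img2img_pipe(
            prompt=prompt,
            image=image,
            strength=strength,
            generator=generator,
            num_inference_steps=steps,
        ).images[0]
    # Mask drawn: repaint only the masked region with the inpainting pipeline.
    inpaint_pipe.load_lora_weights("SIGMitch/KIT")
    return inpaint_pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,
        strength=strength,
        generator=generator,
        num_inference_steps=steps,
    ).images[0]

Loading the LoRA on every call matches what the commit does; in a long-running Space it could instead be loaded once at startup.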