fountai committed
Commit 2645a2e · 1 Parent(s): 6da1508
Files changed (2)
  1. app.py +152 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,152 @@
import subprocess

def download_file(url, output_filename):
    # Quiet wget fetch to a fixed local filename; raises on a non-zero exit code.
    command = ['wget', '-O', output_filename, '-q', url]
    subprocess.run(command, check=True)

url1 = 'https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite'
url2 = 'https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_segmenter/float16/latest/selfie_segmenter.tflite'

filename1 = 'selfie_multiclass_256x256.tflite'
filename2 = 'selfie_segmenter.tflite'

download_file(url1, filename1)
download_file(url2, filename2)  # downloaded but not used below; kept from the original commit

import cv2
import mediapipe as mp
import numpy as np
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import random
import gradio as gr
import spaces
import torch
from diffusers import FluxInpaintPipeline

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

bfl_repo = "black-forest-labs/FLUX.1-dev"

BG_COLOR = (0, 0, 0)          # black
MASK_COLOR = (255, 255, 255)  # white

def maskHead(input_path):
    # Build a binary head mask (hair + face skin) with MediaPipe's multiclass
    # selfie segmenter.
    base_options = python.BaseOptions(model_asset_path='selfie_multiclass_256x256.tflite')
    options = vision.ImageSegmenterOptions(base_options=base_options,
                                           output_category_mask=True)

    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        image = mp.Image.create_from_file(input_path)

        segmentation_result = segmenter.segment(image)

        # Multiclass model categories: 0=background, 1=hair, 2=body skin,
        # 3=face skin, 4=clothes, 5=others.
        hairmask = segmentation_result.confidence_masks[1]
        facemask = segmentation_result.confidence_masks[3]

        image_data = image.numpy_view()
        fg_image = np.zeros(image_data.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image_data.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR

        # Union of the two confidence maps, thresholded at 0.2.
        combined_mask = np.maximum(hairmask.numpy_view(), facemask.numpy_view())

        condition = np.stack((combined_mask,) * 3, axis=-1) > 0.2
        output_image = np.where(condition, fg_image, bg_image)

        return output_image

def random_positioning(input_img, output_size=(1024, 1024)):
    # Paste the subject at a random scale and offset on a black canvas.
    if input_img is None:
        raise ValueError("Impossible to load image")

    # Cap the scale so the resized image always fits inside the canvas;
    # the original code could crash on photos larger than the output size.
    fit = min(output_size[0] / input_img.shape[1],
              output_size[1] / input_img.shape[0],
              1.0)
    scale_factor = random.uniform(0.5, 1.0) * fit

    new_size = (int(input_img.shape[1] * scale_factor), int(input_img.shape[0] * scale_factor))

    resized_image = cv2.resize(input_img, new_size, interpolation=cv2.INTER_AREA)

    background = np.zeros((output_size[1], output_size[0], 3), dtype=np.uint8)

    x_offset = random.randint(0, output_size[0] - new_size[0])
    y_offset = random.randint(0, output_size[1] - new_size[1])

    background[y_offset:y_offset + new_size[1], x_offset:x_offset + new_size[0]] = resized_image

    return background


def remove_background(image_path, mask):
    # Keep only the pixels the mask does not cover; currently unused helper.
    image = cv2.imread(image_path)
    inverted_mask = cv2.bitwise_not(mask)

    _, binary_mask = cv2.threshold(inverted_mask, 127, 255, cv2.THRESH_BINARY)

    result = np.zeros_like(image, dtype=np.uint8)

    result[binary_mask == 255] = image[binary_mask == 255]

    return result

pipe = FluxInpaintPipeline.from_pretrained(bfl_repo, torch_dtype=torch.bfloat16).to(DEVICE)
MAX_SEED = np.iinfo(np.int32).max
TRIGGER = "a photo of TOK"

@spaces.GPU(duration=150)
def execute(image, prompt):
    if not prompt:
        gr.Info("Please enter a text prompt.")
        return None

    if not image:
        gr.Info("Please upload an image.")
        return None

    img = cv2.imread(image)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Four random placements of the same subject on the 1024x1024 canvas.
    imgs = [random_positioning(img) for _ in range(4)]

    pipe.load_lora_weights("XLabs-AI/flux-RealismLora", weight_name='lora.safetensors')
    response = []

    seed_slicer = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed_slicer)

    for current_img in imgs:
        # cv2.imwrite expects BGR, so convert back before saving; otherwise the
        # segmenter would see channel-swapped colors.
        cv2.imwrite('base_image.jpg', cv2.cvtColor(current_img, cv2.COLOR_RGB2BGR))
        mask = maskHead('base_image.jpg')
        result = pipe(
            prompt=f"{prompt} {TRIGGER}",
            image=current_img,
            mask_image=mask,
            width=1024,
            height=1024,
            strength=0.85,
            generator=generator,
            num_inference_steps=28,
            max_sequence_length=256,
            joint_attention_kwargs={"scale": 0.9},
        ).images[0]
        response.append(result)

    return response

iface = gr.Interface(
    fn=execute,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(label="Prompt")
    ],
    outputs="gallery"
)

iface.launch(share=True, debug=True)
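Taken together, app.py wires up a small pipeline: MediaPipe segments hair and face skin to build an inpainting mask, and FLUX.1 inpainting (with the XLabs realism LoRA) repaints that head region at four random placements of the subject. A minimal sketch for checking the mask stage on its own, assuming the definitions above are in scope and 'selfie.jpg' is a hypothetical local test photo:

# Quick visual check of the head mask, assuming app.py's functions are already
# defined in the current session and 'selfie.jpg' is a hypothetical test image.
mask = maskHead('selfie.jpg')       # white where hair/face skin was detected
cv2.imwrite('head_mask.png', mask)  # write the mask out for inspection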
requirements.txt ADDED
@@ -0,0 +1,3 @@
mediapipe
diffusers
transformers
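Note: app.py also imports torch, cv2, numpy, gradio and spaces. The assumption here is that the Space runtime provides torch, gradio and spaces, and that cv2 and numpy arrive as mediapipe dependencies; for a local run you would likely need to add at least torch and gradio to this list, and possibly accelerate for model loading.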