Files changed (1) hide show
  1. app.py +130 -131
app.py CHANGED
@@ -42,6 +42,21 @@ bise_net = BiSeNet(n_classes = 19)
42
  bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu")) # device fail
43
  bise_net.cuda()
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ### Load consistentID_model checkpoint
46
  pipe.load_ConsistentID_model(
47
  os.path.dirname(consistentID_path),
@@ -58,52 +73,65 @@ pipe.image_encoder.to(device)
58
  pipe.image_proj_model.to(device)
59
  pipe.FacialEncoder.to(device)
60
 
61
- # @torch.inference_mode()
62
- # def Enhance_prompt(prompt,select_images):
63
-
64
- # llva_prompt = f'Please ignore the image. Enhance the following text prompt for me. You can associate more details with the character\'s gesture, environment, and decent clothing:"{prompt}".'
65
- # args = type('Args', (), {
66
- # "model_path": llva_model_path,
67
- # "model_base": None,
68
- # "model_name": get_model_name_from_path(llva_model_path),
69
- # "query": llva_prompt,
70
- # "conv_mode": None,
71
- # "image_file": select_images,
72
- # "sep": ",",
73
- # "temperature": 0,
74
- # "top_p": None,
75
- # "num_beams": 1,
76
- # "max_new_tokens": 512
77
- # })()
78
- # Enhanced_prompt = eval_model(args, llva_tokenizer, llva_model, llva_image_processor)
79
-
80
- # return Enhanced_prompt
81
 
82
  @spaces.GPU
83
- def process(inputImage,prompt,negative_prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  # hyper-parameter
86
- select_images = load_image(Image.fromarray(inputImage))
87
  num_steps = 50
88
- merge_steps = 30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
 
 
90
  if prompt == "":
91
  prompt = "A man, in a forest"
92
  prompt = "A man, with backpack, in a raining tropical forest, adventuring, holding a flashlight, in mist, seeking animals"
93
  prompt = "A person, in a sowm, wearing santa hat and a scarf, with a cottage behind"
94
  else:
95
- # prompt=Enhance_prompt(prompt,blank_image) # TODO
96
- prompt=prompt
97
  print(prompt)
98
  pass
99
-
100
  if negative_prompt == "":
101
- negative_prompt = ",monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
102
 
103
- # Extend Prompt
104
  prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
105
 
106
- negtive_prompt_group="((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
107
  negative_prompt = negative_prompt + negtive_prompt_group
108
 
109
  seed = torch.randint(0, 1000, (1,)).item()
@@ -111,119 +139,90 @@ def process(inputImage,prompt,negative_prompt):
111
 
112
  images = pipe(
113
  prompt=prompt,
114
- width=512,
115
- height=768,
116
  input_id_images=select_images,
117
  negative_prompt=negative_prompt,
118
  num_images_per_prompt=1,
119
  num_inference_steps=num_steps,
120
  start_merge_step=merge_steps,
121
  generator=generator,
 
 
122
  ).images[0]
123
 
124
  current_date = datetime.today()
 
125
 
126
- output_dir = script_directory + f"/images/gradio_outputs"
127
- if not os.path.exists(output_dir):
128
- os.makedirs(output_dir)
129
-
130
- images.save(os.path.join(output_dir, f"{current_date}-{seed}.jpg"))
131
-
132
- return os.path.join(output_dir, f"{current_date}-{seed}.jpg")
133
-
134
- # # Gets the templates
135
- # script_directory = os.path.dirname(os.path.realpath(__file__))
136
- # # preset_template = glob.glob(script_directory+"/images/templates/*.png")
137
- # preset_template = glob.glob("./images/templates/*.png")
138
- # preset_template = preset_template + glob.glob("./images/templates/*.jpg")
139
-
140
- # # Use Blocks Create Gradio
141
- # with gr.Blocks(title="ConsistentID Demo") as demo:
142
- # gr.Markdown("# ConsistentID Demo")
143
- # gr.Markdown("\
144
- # Put the reference figure to be redrawn into the box below (There is a small probability of referensing failure. You can submit it repeatedly)")
145
- # gr.Markdown("\
146
- # If you find our work interesting, please leave a star in GitHub for us!<br>\
147
- # https://github.com/JackAILab/ConsistentID")
148
- # with gr.Row():
149
- # with gr.Column():
150
- # model_selected_tab = gr.State(0)
151
- # with gr.TabItem("template images") as template_images_tab:
152
- # template_gallery_list = [(i, i) for i in preset_template]
153
- # gallery = gr.Gallery(template_gallery_list,columns=[4], rows=[2], object_fit="contain", height="auto",show_label=False)
154
 
155
- # def select_function(evt: gr.SelectData):
156
- # return preset_template[evt.index]
157
-
158
- # selected_template_images = gr.Text(show_label=False, visible=False, placeholder="Selected")
159
- # gallery.select(select_function, None, selected_template_images)
160
- # with gr.TabItem("Upload Image") as upload_image_tab:
161
- # costum_image = gr.Image(label="Upload Image")
162
-
163
- # model_selected_tabs = [template_images_tab, upload_image_tab]
164
- # for i, tab in enumerate(model_selected_tabs):
165
- # tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[model_selected_tab])
166
-
167
- # with gr.Column():
168
- # prompt_selected_tab = gr.State(0)
169
- # with gr.TabItem("template prompts") as template_prompts_tab:
170
- # prompt_selected = gr.Dropdown(value="A person, police officer, half body shot", elem_id='dropdown', choices=[
171
- # "A woman in a wedding dress",
172
- # "A woman, queen, in a gorgeous palace",
173
- # "A man sitting at the beach with sunset",
174
- # "A person, police officer, half body shot",
175
- # "A man, sailor, in a boat above ocean",
176
- # "A women wearing headphone, listening music",
177
- # "A man, firefighter, half body shot"], label=f"prepared prompts")
178
-
179
- # with gr.TabItem("custom prompt") as custom_prompt_tab:
180
- # prompt = gr.Textbox(label="prompt",placeholder="A man/woman wearing a santa hat")
181
- # nagetive_prompt = gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry")
182
 
183
- # prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab]
184
- # for i, tab in enumerate(prompt_selected_tabs):
185
- # tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])
186
 
187
- # retouching = gr.Checkbox(label="face retouching",value=False)
188
- # width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8)
189
- # height = gr.Slider(label="image height",minimum=256,maximum=768,value=768,step=8)
190
- # width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
191
- # height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width])
192
- # merge_steps = gr.Slider(label="step starting to merge facial details(30 is recommended)",minimum=10,maximum=50,value=30,step=1)
193
 
194
- # btn = gr.Button("Run")
195
- # with gr.Column():
196
- # out = gr.Image(label="Output")
197
- # gr.Markdown('''
198
- # N.B.:<br/>
199
- # - If the proportion of face in the image is too small, the probability of an error will be slightly higher, and the similarity will also significantly decrease.)
200
- # - At the same time, use prompt with \"man\" or \"woman\" instead of \"person\" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.
201
- # - Due to insufficient graphics memory on the demo server, there is an upper limit on the resolution for generating samples. We will support the generation of SDXL as soon as possible<br/><br/>
202
- # ''')
203
- # btn.click(fn=process, inputs=[selected_template_images,costum_image,prompt,nagetive_prompt,prompt_selected,retouching
204
- # ,model_selected_tab,prompt_selected_tab,width,height,merge_steps], outputs=out)
205
-
206
- iface = gr.Interface(
207
- fn=process,
208
- inputs=[
209
- gr.Image(label="Upload Image"),
210
- gr.Textbox(label="prompt",placeholder="A man, in a forest, adventuring"),
211
- gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry"),
212
- ],
213
- outputs=[
214
- gr.Image(label="Output"),
215
- ],
216
- title="ConsistentID Demo",
217
- description="Put reference portrait below" ,
218
- allow_flagging="never"
219
- )
220
-
221
- iface.launch() # zero.device
222
-
223
- # @spaces.GPU
224
- # def greet(n):
225
- # print(zero.device) # <-- 'cuda:0' 🤗
226
- # return f"Hello {zero + n} Tensor"
227
-
228
- # demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
229
- # demo.launch()
 
42
  bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu")) # device fail
43
  bise_net.cuda()
44
 
45
+ sys.path.append("./models/LLaVA")
46
+ from llava.model.builder import load_pretrained_model
47
+ from llava.mm_utils import get_model_name_from_path
48
+ from llava.eval.run_llava import eval_model
49
+
50
+ # Load Llava for prompt enhancement
51
+ llva_model_path = "liuhaotian/llava-v1.5-7b"
52
+ llva_tokenizer, llva_model, llva_image_processor, llva_context_len = load_pretrained_model(
53
+ model_path=llva_model_path,
54
+ model_base=None,
55
+ model_name=get_model_name_from_path(llva_model_path),)
56
+ llva_tokenizer.cuda()
57
+ llva_model.cuda()
58
+ llva_image_processor.cuda()
59
+
60
  ### Load consistentID_model checkpoint
61
  pipe.load_ConsistentID_model(
62
  os.path.dirname(consistentID_path),
 
73
  pipe.image_proj_model.to(device)
74
  pipe.FacialEncoder.to(device)
75
 
76
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  @spaces.GPU
79
+ def process(selected_template_images,costum_image,prompt
80
+ ,negative_prompt,prompt_selected,retouching,model_selected_tab,prompt_selected_tab,width,height,merge_steps):
81
+
82
+ if model_selected_tab==0:
83
+ select_images = load_image(Image.open(selected_template_images))
84
+ else:
85
+ select_images = load_image(Image.fromarray(costum_image))
86
+
87
+ if prompt_selected_tab==0:
88
+ prompt = prompt_selected
89
+ negative_prompt = ""
90
+ need_safetycheck = False
91
+ else:
92
+ need_safetycheck = True
93
 
94
  # hyper-parameter
 
95
  num_steps = 50
96
+ # merge_steps = 30
97
+
98
+ @torch.inference_mode()
99
+ def Enhance_prompt(prompt,select_images):
100
+
101
+ llva_prompt = f'Please ignore the image. Enhance the following text prompt for me. You can associate more details with the character\'s gesture, environment, and decent clothing:"{prompt}".'
102
+ args = type('Args', (), {
103
+ "model_path": llva_model_path,
104
+ "model_base": None,
105
+ "model_name": get_model_name_from_path(llva_model_path),
106
+ "query": llva_prompt,
107
+ "conv_mode": None,
108
+ "image_file": select_images,
109
+ "sep": ",",
110
+ "temperature": 0,
111
+ "top_p": None,
112
+ "num_beams": 1,
113
+ "max_new_tokens": 512
114
+ })()
115
+ Enhanced_prompt = eval_model(args, llva_tokenizer, llva_model, llva_image_processor)
116
 
117
+ return Enhanced_prompt
118
+
119
  if prompt == "":
120
  prompt = "A man, in a forest"
121
  prompt = "A man, with backpack, in a raining tropical forest, adventuring, holding a flashlight, in mist, seeking animals"
122
  prompt = "A person, in a sowm, wearing santa hat and a scarf, with a cottage behind"
123
  else:
124
+ prompt=Enhance_prompt(prompt,Image.new('RGB', (200, 200), color = 'white'))
 
125
  print(prompt)
126
  pass
127
+
128
  if negative_prompt == "":
129
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
130
 
131
+ #Extend Prompt
132
  prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
133
 
134
+ negtive_prompt_group="((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
135
  negative_prompt = negative_prompt + negtive_prompt_group
136
 
137
  seed = torch.randint(0, 1000, (1,)).item()
 
139
 
140
  images = pipe(
141
  prompt=prompt,
142
+ width=width,
143
+ height=height,
144
  input_id_images=select_images,
145
  negative_prompt=negative_prompt,
146
  num_images_per_prompt=1,
147
  num_inference_steps=num_steps,
148
  start_merge_step=merge_steps,
149
  generator=generator,
150
+ retouching=retouching,
151
+ need_safetycheck=need_safetycheck,
152
  ).images[0]
153
 
154
  current_date = datetime.today()
155
+ return np.array(images)
156
 
157
+ # Gets the templates
158
+ script_directory = os.path.dirname(os.path.realpath(__file__))
159
+ preset_template = glob.glob("./images/templates/*.png")
160
+ preset_template = preset_template + glob.glob("./images/templates/*.jpg")
161
+
162
+
163
+ with gr.Blocks(title="ConsistentID Demo") as demo:
164
+ gr.Markdown("# ConsistentID Demo")
165
+ gr.Markdown("\
166
+ Put the reference figure to be redrawn into the box below (There is a small probability of referensing failure. You can submit it repeatedly)")
167
+ gr.Markdown("\
168
+ If you find our work interesting, please leave a star in GitHub for us!<br>\
169
+ https://github.com/JackAILab/ConsistentID")
170
+ with gr.Row():
171
+ with gr.Column():
172
+ model_selected_tab = gr.State(0)
173
+ with gr.TabItem("template images") as template_images_tab:
174
+ template_gallery_list = [(i, i) for i in preset_template]
175
+ gallery = gr.Gallery(template_gallery_list,columns=[4], rows=[2], object_fit="contain", height="auto",show_label=False)
 
 
 
 
 
 
 
 
 
176
 
177
+ def select_function(evt: gr.SelectData):
178
+ return preset_template[evt.index]
179
+
180
+ selected_template_images = gr.Text(show_label=False, visible=False, placeholder="Selected")
181
+ gallery.select(select_function, None, selected_template_images)
182
+ with gr.TabItem("Upload Image") as upload_image_tab:
183
+ costum_image = gr.Image(label="Upload Image")
184
+
185
+ model_selected_tabs = [template_images_tab, upload_image_tab]
186
+ for i, tab in enumerate(model_selected_tabs):
187
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[model_selected_tab])
188
+
189
+ with gr.Column():
190
+ prompt_selected_tab = gr.State(0)
191
+ with gr.TabItem("template prompts") as template_prompts_tab:
192
+ prompt_selected = gr.Dropdown(value="A person, police officer, half body shot", elem_id='dropdown', choices=[
193
+ "A woman in a wedding dress",
194
+ "A woman, queen, in a gorgeous palace",
195
+ "A man sitting at the beach with sunset",
196
+ "A person, police officer, half body shot",
197
+ "A man, sailor, in a boat above ocean",
198
+ "A women wearing headphone, listening music",
199
+ "A man, firefighter, half body shot"], label=f"prepared prompts")
200
+
201
+ with gr.TabItem("custom prompt") as custom_prompt_tab:
202
+ prompt = gr.Textbox(label="prompt",placeholder="A man/woman wearing a santa hat")
203
+ nagetive_prompt = gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry")
204
 
205
+ prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab]
206
+ for i, tab in enumerate(prompt_selected_tabs):
207
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])
208
 
209
+ retouching = gr.Checkbox(label="face retouching",value=False)
210
+ width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8)
211
+ height = gr.Slider(label="image height",minimum=256,maximum=768,value=768,step=8)
212
+ width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
213
+ height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width])
214
+ merge_steps = gr.Slider(label="step starting to merge facial details(30 is recommended)",minimum=10,maximum=50,value=30,step=1)
215
 
216
+ btn = gr.Button("Run")
217
+ with gr.Column():
218
+ out = gr.Image(label="Output")
219
+ gr.Markdown('''
220
+ N.B.:<br/>
221
+ - If the proportion of face in the image is too small, the probability of an error will be slightly higher, and the similarity will also significantly decrease.)
222
+ - At the same time, use prompt with \"man\" or \"woman\" instead of \"person\" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.
223
+ - Due to insufficient graphics memory on the demo server, there is an upper limit on the resolution for generating samples. We will support the generation of SDXL as soon as possible<br/><br/>
224
+ ''')
225
+ btn.click(fn=process, inputs=[selected_template_images,costum_image,prompt,nagetive_prompt,prompt_selected,retouching
226
+ ,model_selected_tab,prompt_selected_tab,width,height,merge_steps], outputs=out)
227
+
228
+ demo.launch()