shgao committed on
Commit
2474e74
•
1 Parent(s): 03a83ad

update to new version

app.py CHANGED
@@ -15,9 +15,7 @@ SHARED_UI_WARNING = f'''### [NOTE] Inference may be slow in this shared UI.
15
  You can duplicate and use it with a paid private GPU.
16
  <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
17
  '''
18
-
19
- #
20
- sam_generator = init_sam_model()
21
  blip_processor = init_blip_processor()
22
  blip_model = init_blip_model()
23
 
@@ -31,30 +29,33 @@ with gr.Blocks() as demo:
31
  controlmodel_name='LAION Pretrained(v0-4)-SD21',
32
  lora_model_path=None, use_blip=True, extra_inpaint=False,
33
  sam_generator=sam_generator,
 
34
  blip_processor=blip_processor,
35
  blip_model=blip_model)
36
- create_demo_edit_anything(model.process)
37
  with gr.TabItem(' 👩‍🦰Beauty Edit/Generation'):
38
  lora_model_path = hf_hub_download(
39
  "mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")
40
  model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
41
  lora_model_path=lora_model_path, use_blip=True, extra_inpaint=True,
42
  sam_generator=sam_generator,
 
43
  blip_processor=blip_processor,
44
  blip_model=blip_model,
45
  lora_weight=0.5,
46
  )
47
- create_demo_beauty(model.process)
48
- # with gr.TabItem(' 👨‍🌾Handsome Edit/Generation'):
49
- # model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "Realistic_Vision_V2.0"),
50
- # lora_model_path=None, use_blip=True, extra_inpaint=True,
51
- # sam_generator=sam_generator,
52
- # blip_processor=blip_processor,
53
- # blip_model=blip_model)
54
- # create_demo_handsome(model.process)
 
55
  # with gr.TabItem('Generate Anything'):
56
  # create_demo_generate_anything()
57
  with gr.Tabs():
58
  gr.Markdown(SHARED_UI_WARNING)
59
 
60
- demo.queue(api_open=False).launch()
 
15
  You can duplicate and use it with a paid private GPU.
16
  <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
17
  '''
18
+ sam_generator, mask_predictor = init_sam_model()
 
 
19
  blip_processor = init_blip_processor()
20
  blip_model = init_blip_model()
21
 
 
29
  controlmodel_name='LAION Pretrained(v0-4)-SD21',
30
  lora_model_path=None, use_blip=True, extra_inpaint=False,
31
  sam_generator=sam_generator,
32
+ mask_predictor=mask_predictor,
33
  blip_processor=blip_processor,
34
  blip_model=blip_model)
35
+ create_demo_edit_anything(model.process, model.process_image_click)
36
  with gr.TabItem(' 👩‍🦰Beauty Edit/Generation'):
37
  lora_model_path = hf_hub_download(
38
  "mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")
39
  model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
40
  lora_model_path=lora_model_path, use_blip=True, extra_inpaint=True,
41
  sam_generator=sam_generator,
42
+ mask_predictor=mask_predictor,
43
  blip_processor=blip_processor,
44
  blip_model=blip_model,
45
  lora_weight=0.5,
46
  )
47
+ create_demo_beauty(model.process, model.process_image_click)
48
+ with gr.TabItem(' 👨‍🌾Handsome Edit/Generation'):
49
+ model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "Realistic_Vision_V2.0"),
50
+ lora_model_path=None, use_blip=True, extra_inpaint=True,
51
+ sam_generator=sam_generator,
52
+ mask_predictor=mask_predictor,
53
+ blip_processor=blip_processor,
54
+ blip_model=blip_model)
55
+ create_demo_handsome(model.process, model.process_image_click)
56
  # with gr.TabItem('Generate Anything'):
57
  # create_demo_generate_anything()
58
  with gr.Tabs():
59
  gr.Markdown(SHARED_UI_WARNING)
60
 
61
+ demo.queue(api_open=False).launch(server_name='0.0.0.0', share=False)
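The rewritten app.py initializes SAM once and shares both the automatic mask generator and the new point-prompt predictor across every tab, so the ViT-H backbone is only loaded a single time. A minimal sketch of that wiring, assuming `init_sam_model` and `EditAnythingLoraModel` come from sam2edit_lora.py and that the tab builder is imported from sam2edit.py (illustrative only, not a drop-in app.py):

```python
# Sketch only: one SAM backbone, reused by all demo tabs.
import gradio as gr
from sam2edit_lora import EditAnythingLoraModel, init_sam_model
from sam2edit import create_demo as create_demo_edit_anything  # assumed import path

sam_generator, mask_predictor = init_sam_model()  # loads models/sam_vit_h_4b8939.pth once

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem('Edit Anything'):
            model = EditAnythingLoraModel(
                base_model_path="stabilityai/stable-diffusion-2-inpainting",
                controlmodel_name='LAION Pretrained(v0-4)-SD21',
                lora_model_path=None, use_blip=True, extra_inpaint=False,
                sam_generator=sam_generator,    # automatic "segment everything" masks
                mask_predictor=mask_predictor,  # point-prompted masks for the Click tab
            )
            # Each tab now wires up both the brush handler and the click handler.
            create_demo_edit_anything(model.process, model.process_image_click)

demo.queue(api_open=False).launch(server_name="0.0.0.0", share=False)
```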
requirements.txt CHANGED
@@ -3,7 +3,7 @@ torch==1.13.1+cu117
3
  torchvision==0.14.1+cu117
4
  torchaudio==0.13.1
5
  numpy==1.23.1
6
- gradio==3.25.0
7
  gradio_client==0.1.4
8
  albumentations==1.3.0
9
  opencv-contrib-python==4.3.0.36
 
3
  torchvision==0.14.1+cu117
4
  torchaudio==0.13.1
5
  numpy==1.23.1
6
+ gradio==3.30.0
7
  gradio_client==0.1.4
8
  albumentations==1.3.0
9
  opencv-contrib-python==4.3.0.36
sam2edit.py CHANGED
@@ -1,82 +1,28 @@
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
 
2
  import gradio as gr
3
  from diffusers.utils import load_image
4
  from sam2edit_lora import EditAnythingLoraModel, config_dict
 
 
5
 
6
 
7
- def create_demo(process):
8
 
9
-
10
-
11
- print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
12
- WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
13
- We are not responsible for possible risks using this model.
14
  '''
15
- block = gr.Blocks()
16
- with block as demo:
17
- with gr.Row():
18
- gr.Markdown(
19
- "## EditAnything https://github.com/sail-sg/EditAnything ")
20
- with gr.Row():
21
- with gr.Column():
22
- source_image = gr.Image(
23
- source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
24
- enable_all_generate = gr.Checkbox(
25
- label='Auto generation on all region.', value=False)
26
- prompt = gr.Textbox(
27
- label="Prompt (Text in the expected things of edited region)")
28
- enable_auto_prompt = gr.Checkbox(
29
- label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
30
- a_prompt = gr.Textbox(
31
- label="Added Prompt", value='best quality, extremely detailed')
32
- n_prompt = gr.Textbox(label="Negative Prompt",
33
- value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
34
- control_scale = gr.Slider(
35
- label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
36
- run_button = gr.Button(label="Run")
37
- num_samples = gr.Slider(
38
- label="Images", minimum=1, maximum=12, value=2, step=1)
39
- seed = gr.Slider(label="Seed", minimum=-1,
40
- maximum=2147483647, step=1, randomize=True)
41
- enable_tile = gr.Checkbox(
42
- label='Tile refinement for high resolution generation.', value=True)
43
- with gr.Accordion("Advanced options", open=False):
44
- mask_image = gr.Image(
45
- source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
46
- image_resolution = gr.Slider(
47
- label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
48
- strength = gr.Slider(
49
- label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
50
- guess_mode = gr.Checkbox(
51
- label='Guess Mode', value=False)
52
- detect_resolution = gr.Slider(
53
- label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
54
- ddim_steps = gr.Slider(
55
- label="Steps", minimum=1, maximum=100, value=30, step=1)
56
- scale = gr.Slider(
57
- label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
58
- eta = gr.Number(label="eta (DDIM)", value=0.0)
59
- with gr.Column():
60
- result_gallery = gr.Gallery(
61
- label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
62
- result_text = gr.Text(label='BLIP2+Human Prompt Text')
63
- ips = [source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
64
- detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile]
65
- run_button.click(fn=process, inputs=ips, outputs=[
66
- result_gallery, result_text])
67
- # with gr.Row():
68
- # ex = gr.Examples(examples=examples, fn=process,
69
- # inputs=[a_prompt, n_prompt, scale],
70
- # outputs=[result_gallery],
71
- # cache_examples=False)
72
- with gr.Row():
73
- gr.Markdown(WARNING_INFO)
74
  return demo
75
 
76
 
77
  if __name__ == '__main__':
78
- model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2-inpainting",
79
- controlmodel_name='LAION Pretrained(v0-4)-SD21', extra_inpaint=False,
80
  lora_model_path=None, use_blip=True)
81
- demo = create_demo(model.process)
82
  demo.queue().launch(server_name='0.0.0.0')
 
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ import os
3
  import gradio as gr
4
  from diffusers.utils import load_image
5
  from sam2edit_lora import EditAnythingLoraModel, config_dict
6
+ from sam2edit_demo import create_demo_template
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
 
9
 
10
+ def create_demo(process, process_image_click=None):
11
 
12
+ examples = None
13
+ INFO = f'''
14
+ ## EditAnything https://github.com/sail-sg/EditAnything
 
 
15
  '''
16
+ WARNING_INFO = None
17
+
18
+ demo = create_demo_template(process, process_image_click, examples=examples,
19
+ INFO=INFO, WARNING_INFO=WARNING_INFO, enable_auto_prompt_default=True)
20
  return demo
21
 
22
 
23
  if __name__ == '__main__':
24
+ model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2",
25
+ controlmodel_name='LAION Pretrained(v0-4)-SD21', extra_inpaint=True,
26
  lora_model_path=None, use_blip=True)
27
+ demo = create_demo(model.process, model.process_image_click)
28
  demo.queue().launch(server_name='0.0.0.0')
sam2edit_beauty.py CHANGED
@@ -1,10 +1,13 @@
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
 
2
  import gradio as gr
3
  from diffusers.utils import load_image
4
  from sam2edit_lora import EditAnythingLoraModel, config_dict
 
 
5
 
6
 
7
- def create_demo(process):
8
 
9
  examples = [
10
  ["dudou,1girl, beautiful face, solo, candle, brown hair, long hair, <lora:flowergirl:0.9>,ulzzang-6500-v1.1,(raw photo:1.2),((photorealistic:1.4))best quality ,masterpiece, illustration, an extremely delicate and beautiful, extremely detailed ,CG ,unity ,8k wallpaper, Amazing, finely detail, masterpiece,best quality,official art,extremely detailed CG unity 8k wallpaper,absurdres, incredibly absurdres, huge filesize, ultra-detailed, highres, extremely detailed,beautiful detailed girl, extremely detailed eyes and face, beautiful detailed eyes,cinematic lighting,1girl,see-through,looking at viewer,full body,full-body shot,outdoors,arms behind back,(chinese clothes) <lora:cuteGirlMix4_v10:1>",
@@ -16,77 +19,26 @@ def create_demo(process):
16
  ["mix4, whole body shot, ((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional close-up photo, shy expression, cute, beautiful detailed girl, detailed fingers, extremely detailed eyes and face, beautiful detailed nose, beautiful detailed eyes, long eyelashes, light on face, looking at viewer, (closed mouth:1.2), 1girl, cute, young, mature face, (full body:1.3), ((small breasts)), realistic face, realistic body, beautiful detailed thigh,s, same eyes color, (realistic, photo realism:1. 37), (highest quality), (best shadow), (best illustration), ultra high resolution, physics-based rendering, cinematic lighting), solo, 1girl, highly detailed, in office, detailed office, open cardigan, ponytail contorted, beautiful eyes ,sitting in office,dating, business suit, cross-laced clothes, collared shirt, beautiful breast, small breast, Chinese dress, white pantyhose, natural breasts, pink and white hair, <lora:cuteGirlMix4_v10:1>",
17
  "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), cloth, underwear, bra, low-res, normal quality, ((monochrome)), ((grayscale)), skin spots, acne, skin blemishes, age spots, glans, bad nipples, long nipples, bad vagina, extra fingers,fewer fingers,strange fingers,bad hand, ng_deepnegative_v1_75t, bad-picture-chill-75v", 7]
18
  ]
19
-
20
- print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
 
 
21
  WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
22
  We are not responsible for possible risks using this model.
23
-
24
  Lora model from https://civitai.com/models/14171/cutegirlmix4 Thanks!
25
  '''
26
- block = gr.Blocks()
27
- with block as demo:
28
- with gr.Row():
29
- gr.Markdown(
30
- "## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything ")
31
- with gr.Row():
32
- with gr.Column():
33
- source_image = gr.Image(
34
- source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
35
- enable_all_generate = gr.Checkbox(
36
- label='Auto generation on all region.', value=False)
37
- prompt = gr.Textbox(
38
- label="Prompt (Text in the expected things of edited region)")
39
- enable_auto_prompt = gr.Checkbox(
40
- label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
41
- a_prompt = gr.Textbox(
42
- label="Added Prompt", value='best quality, extremely detailed')
43
- n_prompt = gr.Textbox(label="Negative Prompt",
44
- value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
45
- control_scale = gr.Slider(
46
- label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
47
- run_button = gr.Button(label="Run")
48
- num_samples = gr.Slider(
49
- label="Images", minimum=1, maximum=12, value=2, step=1)
50
- seed = gr.Slider(label="Seed", minimum=-1,
51
- maximum=2147483647, step=1, randomize=True)
52
- enable_tile = gr.Checkbox(
53
- label='Tile refinement for high resolution generation.', value=True)
54
- with gr.Accordion("Advanced options", open=False):
55
- mask_image = gr.Image(
56
- source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
57
- image_resolution = gr.Slider(
58
- label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
59
- strength = gr.Slider(
60
- label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
61
- guess_mode = gr.Checkbox(
62
- label='Guess Mode', value=False)
63
- detect_resolution = gr.Slider(
64
- label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
65
- ddim_steps = gr.Slider(
66
- label="Steps", minimum=1, maximum=100, value=30, step=1)
67
- scale = gr.Slider(
68
- label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
69
- eta = gr.Number(label="eta (DDIM)", value=0.0)
70
- with gr.Column():
71
- result_gallery = gr.Gallery(
72
- label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
73
- result_text = gr.Text(label='BLIP2+Human Prompt Text')
74
- ips = [source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
75
- detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile]
76
- run_button.click(fn=process, inputs=ips, outputs=[
77
- result_gallery, result_text])
78
- with gr.Row():
79
- ex = gr.Examples(examples=examples, fn=process,
80
- inputs=[a_prompt, n_prompt, scale],
81
- outputs=[result_gallery],
82
- cache_examples=False)
83
- with gr.Row():
84
- gr.Markdown(WARNING_INFO)
85
  return demo
86
 
87
 
88
  if __name__ == '__main__':
89
- model = EditAnythingLoraModel(base_model_path='../chilloutmix_NiPrunedFp32Fix',
90
- lora_model_path='../40806/mix4', use_blip=True, lora_weight=0.5)
91
- demo = create_demo(model.process)
 
 
 
 
 
92
  demo.queue().launch(server_name='0.0.0.0')
 
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ import os
3
  import gradio as gr
4
  from diffusers.utils import load_image
5
  from sam2edit_lora import EditAnythingLoraModel, config_dict
6
+ from sam2edit_demo import create_demo_template
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
 
9
 
10
+ def create_demo(process, process_image_click=None):
11
 
12
  examples = [
13
  ["dudou,1girl, beautiful face, solo, candle, brown hair, long hair, <lora:flowergirl:0.9>,ulzzang-6500-v1.1,(raw photo:1.2),((photorealistic:1.4))best quality ,masterpiece, illustration, an extremely delicate and beautiful, extremely detailed ,CG ,unity ,8k wallpaper, Amazing, finely detail, masterpiece,best quality,official art,extremely detailed CG unity 8k wallpaper,absurdres, incredibly absurdres, huge filesize, ultra-detailed, highres, extremely detailed,beautiful detailed girl, extremely detailed eyes and face, beautiful detailed eyes,cinematic lighting,1girl,see-through,looking at viewer,full body,full-body shot,outdoors,arms behind back,(chinese clothes) <lora:cuteGirlMix4_v10:1>",
 
19
  ["mix4, whole body shot, ((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional close-up photo, shy expression, cute, beautiful detailed girl, detailed fingers, extremely detailed eyes and face, beautiful detailed nose, beautiful detailed eyes, long eyelashes, light on face, looking at viewer, (closed mouth:1.2), 1girl, cute, young, mature face, (full body:1.3), ((small breasts)), realistic face, realistic body, beautiful detailed thigh,s, same eyes color, (realistic, photo realism:1. 37), (highest quality), (best shadow), (best illustration), ultra high resolution, physics-based rendering, cinematic lighting), solo, 1girl, highly detailed, in office, detailed office, open cardigan, ponytail contorted, beautiful eyes ,sitting in office,dating, business suit, cross-laced clothes, collared shirt, beautiful breast, small breast, Chinese dress, white pantyhose, natural breasts, pink and white hair, <lora:cuteGirlMix4_v10:1>",
20
  "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), cloth, underwear, bra, low-res, normal quality, ((monochrome)), ((grayscale)), skin spots, acne, skin blemishes, age spots, glans, bad nipples, long nipples, bad vagina, extra fingers,fewer fingers,strange fingers,bad hand, ng_deepnegative_v1_75t, bad-picture-chill-75v", 7]
21
  ]
22
+ INFO = f'''
23
+ ## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything
24
+ This model is good at generating beautiful female.
25
+ '''
26
  WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
27
  We are not responsible for possible risks using this model.
 
28
  Lora model from https://civitai.com/models/14171/cutegirlmix4 Thanks!
29
  '''
30
+ demo = create_demo_template(process, process_image_click,
31
+ examples=examples, INFO=INFO, WARNING_INFO=WARNING_INFO)
32
  return demo
33
 
34
 
35
  if __name__ == '__main__':
36
+ sd_models_path = snapshot_download("shgao/sdmodels")
37
+ lora_model_path = hf_hub_download(
38
+ "mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")
39
+ model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
40
+ lora_model_path=lora_model_path, use_blip=True, extra_inpaint=True,
41
+ lora_weight=0.5,
42
+ )
43
+ demo = create_demo(model.process, model.process_image_click)
44
  demo.queue().launch(server_name='0.0.0.0')
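Both refactored `__main__` blocks now resolve their checkpoints from the Hugging Face Hub instead of hard-coded local paths. A short sketch of that download step, using the repo ids that appear in this commit (the model construction itself is omitted):

```python
# Sketch: fetch the base Stable Diffusion folder and the LoRA weights from the Hub.
import os
from huggingface_hub import hf_hub_download, snapshot_download

# snapshot_download pulls the whole repo and returns its local cache path.
sd_models_path = snapshot_download("shgao/sdmodels")
base_model_path = os.path.join(sd_models_path, "chilloutmix_NiPrunedFp32Fix")

# hf_hub_download fetches a single file (here the LoRA .safetensors).
lora_model_path = hf_hub_download(
    "mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")

print(base_model_path)
print(lora_model_path)
```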
sam2edit_demo.py ADDED
@@ -0,0 +1,140 @@
1
+ # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ import gradio as gr
3
+
4
+ def create_demo_template(process, process_image_click=None, examples=None,
5
+ INFO='EditAnything https://github.com/sail-sg/EditAnything', WARNING_INFO=None,
6
+ enable_auto_prompt_default=False,
7
+ ):
8
+
9
+ print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
10
+ block = gr.Blocks()
11
+ with block as demo:
12
+ clicked_points = gr.State([])
13
+ origin_image = gr.State(None)
14
+ click_mask = gr.State(None)
15
+ with gr.Row():
16
+ gr.Markdown(INFO)
17
+ with gr.Row().style(equal_height=False):
18
+ with gr.Column():
19
+ with gr.Tab("Click🖱"):
20
+ source_image_click = gr.Image(
21
+ type="pil", interactive=True,
22
+ label="Image: Upload an image and click the region you want to edit.",
23
+ )
24
+ with gr.Column():
25
+ with gr.Row():
26
+ point_prompt = gr.Radio(
27
+ choices=["Foreground Point", "Background Point"],
28
+ value="Foreground Point",
29
+ label="Point Label",
30
+ interactive=True, show_label=False)
31
+ clear_button_click = gr.Button(
32
+ value="Clear Click Points", interactive=True)
33
+ clear_button_image = gr.Button(
34
+ value="Clear Image", interactive=True)
35
+ with gr.Row():
36
+ run_button_click = gr.Button(
37
+ label="Run EditAnything", interactive=True)
38
+ with gr.Tab("Brush🖌️"):
39
+ source_image_brush = gr.Image(
40
+ source='upload',
41
+ label="Image: Upload an image and cover the region you want to edit with sketch",
42
+ type="numpy", tool="sketch"
43
+ )
44
+ run_button = gr.Button(label="Run EditAnything", interactive=True)
45
+ with gr.Column():
46
+ enable_all_generate = gr.Checkbox(
47
+ label='Auto generation on all regions.', value=False)
48
+ control_scale = gr.Slider(
49
+ label="Mask Align strength", info="Large value -> strict alignment with SAM mask", minimum=0, maximum=1, value=1, step=0.1)
50
+ with gr.Column():
51
+ enable_auto_prompt = gr.Checkbox(
52
+ label='Auto generate text prompt from input image with BLIP2', info='Warning: enabling this may make your prompt less effective.', value=enable_auto_prompt_default)
53
+ a_prompt = gr.Textbox(
54
+ label="Positive Prompt", info='Text describing the expected content of the edited region', value='best quality, extremely detailed')
55
+ n_prompt = gr.Textbox(label="Negative Prompt",
56
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, NSFW')
57
+ with gr.Row():
58
+ num_samples = gr.Slider(
59
+ label="Images", minimum=1, maximum=12, value=2, step=1)
60
+ seed = gr.Slider(label="Seed", minimum=-1,
61
+ maximum=2147483647, step=1, randomize=True)
62
+ with gr.Row():
63
+ enable_tile = gr.Checkbox(
64
+ label='Tile refinement for high resolution generation', info='Slow inference', value=True)
65
+ refine_alignment_ratio = gr.Slider(
66
+ label="Alignment Strength", info='Large value -> strict alignment with input image. Small value -> strong global consistency', minimum=0.0, maximum=1.0, value=0.95, step=0.05)
67
+
68
+ with gr.Accordion("Advanced options", open=False):
69
+ mask_image = gr.Image(
70
+ source='upload', label="Upload a predefined mask of the edit region if you do not want to write your prompt.", info="(Optional: switch to Brush mode when using this!)", type="numpy", value=None)
71
+ image_resolution = gr.Slider(
72
+ label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
73
+ refine_image_resolution = gr.Slider(
74
+ label="Refine Image Resolution", minimum=256, maximum=8192, value=1024, step=64)
75
+ strength = gr.Slider(
76
+ label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
77
+ guess_mode = gr.Checkbox(
78
+ label='Guess Mode', value=False)
79
+ detect_resolution = gr.Slider(
80
+ label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
81
+ ddim_steps = gr.Slider(
82
+ label="Steps", minimum=1, maximum=100, value=30, step=1)
83
+ scale = gr.Slider(
84
+ label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
85
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
86
+ with gr.Column():
87
+ result_gallery_refine = gr.Gallery(
88
+ label='Output High quality', show_label=True, elem_id="gallery").style(grid=2, preview=False)
89
+ result_gallery_init = gr.Gallery(
90
+ label='Output Low quality', show_label=True, elem_id="gallery").style(grid=2, height='auto')
91
+ result_gallery_ref = gr.Gallery(
92
+ label='Output Ref', show_label=False, elem_id="gallery").style(grid=2, height='auto')
93
+ result_text = gr.Text(label='BLIP2+Human Prompt Text')
94
+
95
+ ips = [source_image_brush, enable_all_generate, mask_image, control_scale, enable_auto_prompt, a_prompt, n_prompt, num_samples, image_resolution,
96
+ detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile, refine_alignment_ratio, refine_image_resolution]
97
+ run_button.click(fn=process, inputs=ips, outputs=[
98
+ result_gallery_refine, result_gallery_init, result_gallery_ref, result_text])
99
+
100
+ ip_click = [origin_image, enable_all_generate, click_mask, control_scale, enable_auto_prompt, a_prompt, n_prompt, num_samples, image_resolution,
101
+ detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile, refine_alignment_ratio, refine_image_resolution]
102
+
103
+ run_button_click.click(fn=process,
104
+ inputs=ip_click,
105
+ outputs=[result_gallery_refine, result_gallery_init, result_gallery_ref, result_text])
106
+
107
+ source_image_click.upload(
108
+ lambda image: image.copy() if image is not None else None,
109
+ inputs=[source_image_click],
110
+ outputs=[origin_image]
111
+ )
112
+ source_image_click.select(
113
+ process_image_click,
114
+ inputs=[origin_image, point_prompt,
115
+ clicked_points, image_resolution],
116
+ outputs=[source_image_click, clicked_points, click_mask],
117
+ show_progress=True, queue=True
118
+ )
119
+ clear_button_click.click(
120
+ fn=lambda original_image: (original_image.copy(), [], None)
121
+ if original_image is not None else (None, [], None),
122
+ inputs=[origin_image],
123
+ outputs=[source_image_click, clicked_points, click_mask]
124
+ )
125
+ clear_button_image.click(
126
+ fn=lambda: (None, [], None, None, None),
127
+ inputs=[],
128
+ outputs=[source_image_click, clicked_points,
129
+ click_mask, result_gallery_init, result_text]
130
+ )
131
+ if examples is not None:
132
+ with gr.Row():
133
+ ex = gr.Examples(examples=examples, fn=process,
134
+ inputs=[a_prompt, n_prompt, scale],
135
+ outputs=[result_gallery_init],
136
+ cache_examples=False)
137
+ if WARNING_INFO is not None:
138
+ with gr.Row():
139
+ gr.Markdown(WARNING_INFO)
140
+ return demo
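sam2edit_demo.py now centralizes the Gradio layout that the per-model scripts used to duplicate; each variant only supplies its callbacks, examples, and info strings. A hypothetical smoke test with stub callbacks (not part of the repo) to preview the shared UI without loading SD/SAM/BLIP2:

```python
# Hypothetical layout preview: the stub callbacks below are placeholders, not repo code.
import gradio as gr
from sam2edit_demo import create_demo_template

def fake_process(*args):
    # The real process() returns (refined gallery, initial gallery, ref gallery, prompt text).
    return [], [], [], "stub prompt"

def fake_process_image_click(origin_image, point_prompt, clicked_points,
                             image_resolution, evt: gr.SelectData):
    # The real handler returns (annotated image, clicked points, click mask).
    return origin_image, clicked_points, None

demo = create_demo_template(fake_process, fake_process_image_click,
                            examples=None,
                            INFO='## Layout preview only',
                            WARNING_INFO=None)
demo.queue().launch(server_name='0.0.0.0')
```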
sam2edit_handsome.py CHANGED
@@ -1,87 +1,37 @@
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
 
2
  import gradio as gr
3
  from diffusers.utils import load_image
4
  from sam2edit_lora import EditAnythingLoraModel, config_dict
 
 
5
 
6
 
7
-
8
- def create_demo(process):
9
 
10
  examples = [
11
- ["1man, muscle,full body, vest, short straight hair, glasses, Gym, barbells, dumbbells, treadmills, boxing rings, squat racks, plates, dumbbell racks soft lighting, masterpiece, best quality, 8k uhd, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6>", "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 6],
12
- ["1man, 25 years- old, full body, wearing long-sleeve white shirt and tie, muscular rand black suit, soft lighting, masterpiece, best quality, 8k uhd, dslr, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6> <lora:uncutPenisLora_v10:0.6>","(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",6],
 
 
13
  ]
14
 
15
  print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
 
 
 
 
 
16
  WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
17
  We are not responsible for possible risks using this model.
18
  Base model from https://huggingface.co/SG161222/Realistic_Vision_V2.0 Thanks!
19
  '''
20
- block = gr.Blocks()
21
- with block as demo:
22
- with gr.Row():
23
- gr.Markdown(
24
- "## Generate Your Handsome powered by EditAnything https://github.com/sail-sg/EditAnything ")
25
- with gr.Row():
26
- with gr.Column():
27
- source_image = gr.Image(
28
- source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
29
- enable_all_generate = gr.Checkbox(
30
- label='Auto generation on all region.', value=False)
31
- prompt = gr.Textbox(
32
- label="Prompt (Text in the expected things of edited region)")
33
- enable_auto_prompt = gr.Checkbox(
34
- label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
35
- a_prompt = gr.Textbox(
36
- label="Added Prompt", value='best quality, extremely detailed')
37
- n_prompt = gr.Textbox(label="Negative Prompt",
38
- value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
39
- control_scale = gr.Slider(
40
- label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
41
- run_button = gr.Button(label="Run")
42
- num_samples = gr.Slider(
43
- label="Images", minimum=1, maximum=12, value=2, step=1)
44
- seed = gr.Slider(label="Seed", minimum=-1,
45
- maximum=2147483647, step=1, randomize=True)
46
- enable_tile = gr.Checkbox(
47
- label='Tile refinement for high resolution generation.', value=True)
48
- with gr.Accordion("Advanced options", open=False):
49
- mask_image = gr.Image(
50
- source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
51
- image_resolution = gr.Slider(
52
- label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
53
- strength = gr.Slider(
54
- label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
55
- guess_mode = gr.Checkbox(
56
- label='Guess Mode', value=False)
57
- detect_resolution = gr.Slider(
58
- label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
59
- ddim_steps = gr.Slider(
60
- label="Steps", minimum=1, maximum=100, value=30, step=1)
61
- scale = gr.Slider(
62
- label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
63
- eta = gr.Number(label="eta (DDIM)", value=0.0)
64
- with gr.Column():
65
- result_gallery = gr.Gallery(
66
- label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
67
- result_text = gr.Text(label='BLIP2+Human Prompt Text')
68
- ips = [source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
69
- detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile]
70
- run_button.click(fn=process, inputs=ips, outputs=[
71
- result_gallery, result_text])
72
- with gr.Row():
73
- ex = gr.Examples(examples=examples, fn=process,
74
- inputs=[a_prompt, n_prompt, scale],
75
- outputs=[result_gallery],
76
- cache_examples=False)
77
- with gr.Row():
78
- gr.Markdown(WARNING_INFO)
79
  return demo
80
 
81
 
82
-
83
  if __name__ == '__main__':
84
- model = EditAnythingLoraModel(base_model_path= '../../gradio-rel/EditAnything/models/Realistic_Vision_V2.0',
85
- lora_model_path= '../../gradio-rel/EditAnything/models/asianmale', use_blip=True)
86
- demo = create_demo(model.process)
87
  demo.queue().launch(server_name='0.0.0.0')
 
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ import os
3
  import gradio as gr
4
  from diffusers.utils import load_image
5
  from sam2edit_lora import EditAnythingLoraModel, config_dict
6
+ from sam2edit_demo import create_demo_template
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
 
9
 
10
+ def create_demo(process, process_image_click=None):
 
11
 
12
  examples = [
13
+ ["1man, muscle,full body, vest, short straight hair, glasses, Gym, barbells, dumbbells, treadmills, boxing rings, squat racks, plates, dumbbell racks soft lighting, masterpiece, best quality, 8k uhd, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6>",
14
+ "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 6],
15
+ ["1man, 25 years- old, full body, wearing long-sleeve white shirt and tie, muscular rand black suit, soft lighting, masterpiece, best quality, 8k uhd, dslr, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6> <lora:uncutPenisLora_v10:0.6>",
16
+ "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 6],
17
  ]
18
 
19
  print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
20
+
21
+ INFO = f'''
22
+ ## Generate Your Handsome powered by EditAnything https://github.com/sail-sg/EditAnything
23
+ This model is good at generating handsome men.
24
+ '''
25
  WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
26
  We are not responsible for possible risks using this model.
27
  Base model from https://huggingface.co/SG161222/Realistic_Vision_V2.0 Thanks!
28
  '''
29
+ demo = create_demo_template(process, process_image_click, examples=examples, INFO=INFO, WARNING_INFO=WARNING_INFO)
30
  return demo
31
 
32
 
 
33
  if __name__ == '__main__':
34
+ model = EditAnythingLoraModel(base_model_path='Realistic_Vision_V2.0',
35
+ lora_model_path=None, use_blip=True)
36
+ demo = create_demo(model.process, model.process_image_click)
37
  demo.queue().launch(server_name='0.0.0.0')
sam2edit_lora.py CHANGED
@@ -14,7 +14,7 @@ import random
14
  import os
15
  import requests
16
  from io import BytesIO
17
- from annotator.util import resize_image, HWC3
18
 
19
  import torch
20
  from safetensors.torch import load_file
@@ -22,7 +22,6 @@ from collections import defaultdict
22
  from diffusers import StableDiffusionControlNetPipeline
23
  from diffusers import ControlNetModel, UniPCMultistepScheduler
24
  from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
25
- # from utils.tmp import StableDiffusionControlNetInpaintPipeline
26
  # need the latest transformers
27
  # pip install git+https://github.com/huggingface/transformers.git
28
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
@@ -32,13 +31,13 @@ import PIL.Image
32
  # Segment-Anything init.
33
  # pip install git+https://github.com/facebookresearch/segment-anything.git
34
  try:
35
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
36
  except ImportError:
37
  print('segment_anything not installed')
38
  result = subprocess.run(
39
  ['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
40
  print(f'Install segment_anything {result}')
41
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
42
  if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
43
  result = subprocess.run(
44
  ['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
@@ -52,13 +51,18 @@ config_dict = OrderedDict([
52
  ])
53
 
54
 
55
- def init_sam_model():
 
 
56
  sam_checkpoint = "models/sam_vit_h_4b8939.pth"
57
  model_type = "default"
58
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
59
  sam.to(device=device)
60
- sam_generator = SamAutomaticMaskGenerator(sam)
61
- return sam_generator
 
 
 
62
 
63
 
64
  def init_blip_processor():
@@ -112,7 +116,6 @@ def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
112
  return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
113
 
114
 
115
-
116
  def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
117
  LORA_PREFIX_UNET = "lora_unet"
118
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -241,10 +244,12 @@ def make_inpaint_condition(image, image_mask):
241
  image = torch.from_numpy(image)
242
  return image
243
 
 
244
  def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, generation_only=False, extra_inpaint=True, lora_weight=1.0):
245
  controlnet = []
246
- controlnet.append(ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)) # sam control
247
- if (not generation_only) and extra_inpaint: # inpainting control
 
248
  print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
249
  controlnet.append(
250
  ControlNetModel.from_pretrained(
@@ -271,17 +276,18 @@ def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, g
271
  pipe.enable_model_cpu_offload()
272
  return pipe
273
 
 
274
  def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
275
  controlnet = ControlNetModel.from_pretrained(
276
- 'lllyasviel/control_v11f1e_sd15_tile', torch_dtype=torch.float16) # tile controlnet
277
- if base_model_path=='runwayml/stable-diffusion-v1-5' or base_model_path=='stabilityai/stable-diffusion-2-inpainting':
278
  print("base_model_path", base_model_path)
279
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
280
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
281
  )
282
  else:
283
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
284
- base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
285
  )
286
  if lora_model_path is not None:
287
  pipe = load_lora_weights(
@@ -296,7 +302,6 @@ def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
296
  return pipe
297
 
298
 
299
-
300
  def show_anns(anns):
301
  if len(anns) == 0:
302
  return
@@ -331,9 +336,11 @@ class EditAnythingLoraModel:
331
  blip_model=None,
332
  sam_generator=None,
333
  controlmodel_name='LAION Pretrained(v0-4)-SD15',
334
- extra_inpaint=True, # used when the base model is not an inpainting model.
 
335
  tile_model=None,
336
  lora_weight=1.0,
 
337
  ):
338
  self.device = device
339
  self.use_blip = use_blip
@@ -348,11 +355,8 @@ class EditAnythingLoraModel:
348
  base_model_path, lora_model_path, self.default_controlnet_path, generation_only=False, extra_inpaint=extra_inpaint, lora_weight=lora_weight)
349
 
350
  # Segment-Anything init.
351
- if sam_generator is not None:
352
- self.sam_generator = sam_generator
353
- else:
354
- self.sam_generator = init_sam_model()
355
-
356
  # BLIP2 init.
357
  if use_blip:
358
  if blip_processor is not None:
@@ -369,7 +373,8 @@ class EditAnythingLoraModel:
369
  if tile_model is not None:
370
  self.tile_pipe = tile_model
371
  else:
372
- self.tile_pipe = obtain_tile_model(base_model_path, lora_model_path, lora_weight=lora_weight)
 
373
 
374
  def get_blip2_text(self, image):
375
  inputs = self.blip_processor(image, return_tensors="pt").to(
@@ -384,19 +389,92 @@ class EditAnythingLoraModel:
384
  full_img, res = show_anns(masks)
385
  return full_img, res
386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
  @torch.inference_mode()
388
- def process(self, source_image, enable_all_generate, mask_image,
389
- control_scale,
390
- enable_auto_prompt, prompt, a_prompt, n_prompt,
391
- num_samples, image_resolution, detect_resolution,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  ddim_steps, guess_mode, strength, scale, seed, eta,
393
- enable_tile=True, condition_model=None):
394
 
395
  if condition_model is None:
396
  this_controlnet_path = self.default_controlnet_path
397
  else:
398
  this_controlnet_path = config_dict[condition_model]
399
- input_image = source_image["image"]
 
400
  if mask_image is None:
401
  if enable_all_generate != self.defalut_enable_all_generate:
402
  self.pipe = obtain_generation_model(
@@ -410,6 +488,8 @@ class EditAnythingLoraModel:
410
  (input_image.shape[0], input_image.shape[1], 3))*255
411
  else:
412
  mask_image = source_image["mask"]
 
 
413
  if self.default_controlnet_path != this_controlnet_path:
414
  print("To Use:", this_controlnet_path,
415
  "Current:", self.default_controlnet_path)
@@ -424,10 +504,10 @@ class EditAnythingLoraModel:
424
  print("Generating text:")
425
  blip2_prompt = self.get_blip2_text(input_image)
426
  print("Generated text:", blip2_prompt)
427
- if len(prompt) > 0:
428
- prompt = blip2_prompt + ',' + prompt
429
  else:
430
- prompt = blip2_prompt
431
 
432
  input_image = HWC3(input_image)
433
 
@@ -448,23 +528,23 @@ class EditAnythingLoraModel:
448
  control = torch.stack([control for _ in range(num_samples)], dim=0)
449
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
450
 
451
- mask_image = HWC3(mask_image.astype(np.uint8))
452
  mask_image_tmp = cv2.resize(
453
- mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
454
  mask_image = Image.fromarray(mask_image_tmp)
455
 
456
  if seed == -1:
457
  seed = random.randint(0, 65535)
458
  seed_everything(seed)
459
  generator = torch.manual_seed(seed)
460
- postive_prompt = prompt + ', ' + a_prompt
461
  negative_prompt = n_prompt
462
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
463
  self.pipe, postive_prompt, negative_prompt, "cuda")
464
  prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
465
  negative_prompt_embeds = torch.cat(
466
  [negative_prompt_embeds] * num_samples, dim=0)
467
- if enable_all_generate and self.extra_inpaint:
468
  self.pipe.safety_checker = lambda images, clip_input: (
469
  images, False)
470
  x_samples = self.pipe(
@@ -485,7 +565,8 @@ class EditAnythingLoraModel:
485
  if self.extra_inpaint:
486
  inpaint_image = make_inpaint_condition(img, mask_image_tmp)
487
  print(inpaint_image.shape)
488
- multi_condition_image.append(inpaint_image.type(torch.float16))
 
489
  multi_condition_scale.append(1.0)
490
  x_samples = self.pipe(
491
  image=img,
@@ -501,33 +582,33 @@ class EditAnythingLoraModel:
501
  ).images
502
  results = [x_samples[i] for i in range(num_samples)]
503
 
504
- if True:
505
- img_tile = [PIL.Image.fromarray(resize_image(np.array(x_samples[i]), 768)) for i in range(num_samples)]
506
- # for each in img_tile:
507
- # print("tile",each.size)
508
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
509
  self.tile_pipe, postive_prompt, negative_prompt, "cuda")
510
- prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
511
- negative_prompt_embeds = torch.cat(
512
- [negative_prompt_embeds] * num_samples, dim=0)
513
- x_samples_tile = self.tile_pipe(
514
- prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
515
- num_images_per_prompt=num_samples,
516
- num_inference_steps=ddim_steps,
517
- generator=generator,
518
- height=img_tile[0].size[1],
519
- width=img_tile[0].size[0],
520
- image=img_tile,
521
- controlnet_conditioning_scale=1.0,
522
- ).images
523
-
524
- results_tile = [x_samples_tile[i] for i in range(num_samples)]
525
- results = results_tile + results
526
-
527
-
528
-
529
-
530
- return [full_segmask, mask_image] + results, prompt
 
 
531
 
532
  def download_image(url):
533
  response = requests.get(url)
 
14
  import os
15
  import requests
16
  from io import BytesIO
17
+ from annotator.util import resize_image, HWC3, resize_points
18
 
19
  import torch
20
  from safetensors.torch import load_file
 
22
  from diffusers import StableDiffusionControlNetPipeline
23
  from diffusers import ControlNetModel, UniPCMultistepScheduler
24
  from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
 
25
  # need the latest transformers
26
  # pip install git+https://github.com/huggingface/transformers.git
27
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
 
31
  # Segment-Anything init.
32
  # pip install git+https://github.com/facebookresearch/segment-anything.git
33
  try:
34
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
35
  except ImportError:
36
  print('segment_anything not installed')
37
  result = subprocess.run(
38
  ['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
39
  print(f'Install segment_anything {result}')
40
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
41
  if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
42
  result = subprocess.run(
43
  ['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
 
51
  ])
52
 
53
 
54
+ def init_sam_model(sam_generator=None, mask_predictor=None):
55
+ if sam_generator is not None and mask_predictor is not None:
56
+ return sam_generator, mask_predictor
57
  sam_checkpoint = "models/sam_vit_h_4b8939.pth"
58
  model_type = "default"
59
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
60
  sam.to(device=device)
61
+ sam_generator = SamAutomaticMaskGenerator(
62
+ sam) if sam_generator is None else sam_generator
63
+ mask_predictor = SamPredictor(
64
+ sam) if mask_predictor is None else mask_predictor
65
+ return sam_generator, mask_predictor
66
 
67
 
68
  def init_blip_processor():
 
116
  return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
117
 
118
 
 
119
  def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
120
  LORA_PREFIX_UNET = "lora_unet"
121
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
 
244
  image = torch.from_numpy(image)
245
  return image
246
 
247
+
248
  def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, generation_only=False, extra_inpaint=True, lora_weight=1.0):
249
  controlnet = []
250
+ controlnet.append(ControlNetModel.from_pretrained(
251
+ controlnet_path, torch_dtype=torch.float16)) # sam control
252
+ if (not generation_only) and extra_inpaint: # inpainting control
253
  print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
254
  controlnet.append(
255
  ControlNetModel.from_pretrained(
 
276
  pipe.enable_model_cpu_offload()
277
  return pipe
278
 
279
+
280
  def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
281
  controlnet = ControlNetModel.from_pretrained(
282
+ 'lllyasviel/control_v11f1e_sd15_tile', torch_dtype=torch.float16) # tile controlnet
283
+ if base_model_path == 'runwayml/stable-diffusion-v1-5' or base_model_path == 'stabilityai/stable-diffusion-2-inpainting':
284
  print("base_model_path", base_model_path)
285
+ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
286
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
287
  )
288
  else:
289
+ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
290
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
291
  )
292
  if lora_model_path is not None:
293
  pipe = load_lora_weights(
 
302
  return pipe
303
 
304
 
 
305
  def show_anns(anns):
306
  if len(anns) == 0:
307
  return
 
336
  blip_model=None,
337
  sam_generator=None,
338
  controlmodel_name='LAION Pretrained(v0-4)-SD15',
339
+ # used when the base model is not an inpainting model.
340
+ extra_inpaint=True,
341
  tile_model=None,
342
  lora_weight=1.0,
343
+ mask_predictor=None
344
  ):
345
  self.device = device
346
  self.use_blip = use_blip
 
355
  base_model_path, lora_model_path, self.default_controlnet_path, generation_only=False, extra_inpaint=extra_inpaint, lora_weight=lora_weight)
356
 
357
  # Segment-Anything init.
358
+ self.sam_generator, self.mask_predictor = init_sam_model(
359
+ sam_generator, mask_predictor)
 
 
 
360
  # BLIP2 init.
361
  if use_blip:
362
  if blip_processor is not None:
 
373
  if tile_model is not None:
374
  self.tile_pipe = tile_model
375
  else:
376
+ self.tile_pipe = obtain_tile_model(
377
+ base_model_path, lora_model_path, lora_weight=lora_weight)
378
 
379
  def get_blip2_text(self, image):
380
  inputs = self.blip_processor(image, return_tensors="pt").to(
 
389
  full_img, res = show_anns(masks)
390
  return full_img, res
391
 
392
+ def get_click_mask(self, image, clicked_points):
393
+ self.mask_predictor.set_image(image)
394
+ # Separate the points and labels
395
+ points, labels = zip(*[(point[:2], point[2])
396
+ for point in clicked_points])
397
+
398
+ # Convert the points and labels to numpy arrays
399
+ input_point = np.array(points)
400
+ input_label = np.array(labels)
401
+
402
+ masks, _, _ = self.mask_predictor.predict(
403
+ point_coords=input_point,
404
+ point_labels=input_label,
405
+ multimask_output=False,
406
+ )
407
+
408
+ return masks
409
+
410
  @torch.inference_mode()
411
+ def process_image_click(self, original_image: gr.Image,
412
+ point_prompt: gr.Radio,
413
+ clicked_points: gr.State,
414
+ image_resolution,
415
+ evt: gr.SelectData):
416
+ # Get the clicked coordinates
417
+ clicked_coords = evt.index
418
+ x, y = clicked_coords
419
+ label = point_prompt
420
+ lab = 1 if label == "Foreground Point" else 0
421
+ clicked_points.append((x, y, lab))
422
+
423
+ input_image = np.array(original_image, dtype=np.uint8)
424
+ H, W, C = input_image.shape
425
+ input_image = HWC3(input_image)
426
+ img = resize_image(input_image, image_resolution)
427
+
428
+ # Update the clicked_points
429
+ resized_points = resize_points(clicked_points,
430
+ input_image.shape,
431
+ image_resolution)
432
+ mask_click_np = self.get_click_mask(img, resized_points)
433
+
434
+ # Convert mask_click_np to HWC format
435
+ mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
436
+
437
+ mask_image = HWC3(mask_click_np.astype(np.uint8))
438
+ mask_image = cv2.resize(
439
+ mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
440
+ # mask_image = Image.fromarray(mask_image_tmp)
441
+
442
+ # Draw circles for all clicked points
443
+ edited_image = input_image
444
+ for x, y, lab in clicked_points:
445
+ # Set the circle color based on the label
446
+ color = (255, 0, 0) if lab == 1 else (0, 0, 255)
447
+
448
+ # Draw the circle
449
+ edited_image = cv2.circle(edited_image, (x, y), 20, color, -1)
450
+
451
+ # Set the opacity for the mask_image and edited_image
452
+ opacity_mask = 0.75
453
+ opacity_edited = 1.0
454
+
455
+ # Combine the edited_image and the mask_image using cv2.addWeighted()
456
+ overlay_image = cv2.addWeighted(
457
+ edited_image, opacity_edited,
458
+ (mask_image * np.array([0/255, 255/255, 0/255])).astype(np.uint8),
459
+ opacity_mask, 0
460
+ )
461
+
462
+ return Image.fromarray(overlay_image), clicked_points, Image.fromarray(mask_image)
463
+
464
+ @torch.inference_mode()
465
+ def process(self, source_image, enable_all_generate, mask_image,
466
+ control_scale,
467
+ enable_auto_prompt, a_prompt, n_prompt,
468
+ num_samples, image_resolution, detect_resolution,
469
  ddim_steps, guess_mode, strength, scale, seed, eta,
470
+ enable_tile=True, refine_alignment_ratio=None, refine_image_resolution=None, condition_model=None):
471
 
472
  if condition_model is None:
473
  this_controlnet_path = self.default_controlnet_path
474
  else:
475
  this_controlnet_path = config_dict[condition_model]
476
+ input_image = source_image["image"] if isinstance(
477
+ source_image, dict) else np.array(source_image, dtype=np.uint8)
478
  if mask_image is None:
479
  if enable_all_generate != self.defalut_enable_all_generate:
480
  self.pipe = obtain_generation_model(
 
488
  (input_image.shape[0], input_image.shape[1], 3))*255
489
  else:
490
  mask_image = source_image["mask"]
491
+ else:
492
+ mask_image = np.array(mask_image, dtype=np.uint8)
493
  if self.default_controlnet_path != this_controlnet_path:
494
  print("To Use:", this_controlnet_path,
495
  "Current:", self.default_controlnet_path)
 
504
  print("Generating text:")
505
  blip2_prompt = self.get_blip2_text(input_image)
506
  print("Generated text:", blip2_prompt)
507
+ if len(a_prompt) > 0:
508
+ a_prompt = blip2_prompt + ',' + a_prompt
509
  else:
510
+ a_prompt = blip2_prompt
511
 
512
  input_image = HWC3(input_image)
513
 
 
528
  control = torch.stack([control for _ in range(num_samples)], dim=0)
529
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
530
 
531
+ mask_imag_ori = HWC3(mask_image.astype(np.uint8))
532
  mask_image_tmp = cv2.resize(
533
+ mask_imag_ori, (W, H), interpolation=cv2.INTER_LINEAR)
534
  mask_image = Image.fromarray(mask_image_tmp)
535
 
536
  if seed == -1:
537
  seed = random.randint(0, 65535)
538
  seed_everything(seed)
539
  generator = torch.manual_seed(seed)
540
+ postive_prompt = a_prompt
541
  negative_prompt = n_prompt
542
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
543
  self.pipe, postive_prompt, negative_prompt, "cuda")
544
  prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
545
  negative_prompt_embeds = torch.cat(
546
  [negative_prompt_embeds] * num_samples, dim=0)
547
+ if enable_all_generate and not self.extra_inpaint:
548
  self.pipe.safety_checker = lambda images, clip_input: (
549
  images, False)
550
  x_samples = self.pipe(
 
565
  if self.extra_inpaint:
566
  inpaint_image = make_inpaint_condition(img, mask_image_tmp)
567
  print(inpaint_image.shape)
568
+ multi_condition_image.append(
569
+ inpaint_image.type(torch.float16))
570
  multi_condition_scale.append(1.0)
571
  x_samples = self.pipe(
572
  image=img,
 
582
  ).images
583
  results = [x_samples[i] for i in range(num_samples)]
584
 
585
+ results_tile = []
586
+ if enable_tile:
 
 
587
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
588
  self.tile_pipe, postive_prompt, negative_prompt, "cuda")
589
+ for i in range(num_samples):
590
+ img_tile = PIL.Image.fromarray(resize_image(
591
+ np.array(x_samples[i]), refine_image_resolution))
592
+ if i == 0:
593
+ mask_image_tile = cv2.resize(
594
+ mask_imag_ori, (img_tile.size[0], img_tile.size[1]), interpolation=cv2.INTER_LINEAR)
595
+ mask_image_tile = Image.fromarray(mask_image_tile)
596
+ x_samples_tile = self.tile_pipe(
597
+ image=img_tile,
598
+ mask_image=mask_image_tile,
599
+ prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
600
+ num_images_per_prompt=1,
601
+ num_inference_steps=ddim_steps,
602
+ generator=generator,
603
+ controlnet_conditioning_image=img_tile,
604
+ height=img_tile.size[1],
605
+ width=img_tile.size[0],
606
+ controlnet_conditioning_scale=1.0,
607
+ alignment_ratio=refine_alignment_ratio,
608
+ ).images
609
+ results_tile += x_samples_tile
610
+
611
+ return results_tile, results, [full_segmask, mask_image], postive_prompt
612
 
613
  def download_image(url):
614
  response = requests.get(url)
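The new Click workflow is built on SAM's `SamPredictor`: `get_click_mask` sets the image once, then requests a mask from labeled points (1 = foreground, 0 = background). A standalone sketch of that call pattern outside the Gradio app; the image path and click coordinates are placeholders:

```python
# Standalone sketch of point-prompted SAM masking (mirrors get_click_mask).
import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["default"](checkpoint="models/sam_vit_h_4b8939.pth")
sam.to(device="cuda")
predictor = SamPredictor(sam)

image = np.array(Image.open("example.jpg").convert("RGB"))  # placeholder image
predictor.set_image(image)  # expensive step: run once per image, reuse for every click

# One foreground click and one background click, in (x, y) pixel coordinates.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240], [40, 40]]),
    point_labels=np.array([1, 0]),
    multimask_output=False,  # single mask of shape (1, H, W), dtype bool
)
mask_uint8 = (masks[0] * 255).astype(np.uint8)  # same post-processing as process_image_click
```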
utils/stable_diffusion_controlnet_inpaint.py CHANGED
@@ -835,6 +835,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixi
835
  callback_steps: int = 1,
836
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
837
  controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
 
838
  ):
839
  r"""
840
  Function invoked when calling the pipeline for generation.
@@ -1115,12 +1116,15 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixi
1115
  progress_bar.update()
1116
  if callback is not None and i % callback_steps == 0:
1117
  callback(i, t, latents)
1118
- # if self.unet.config.in_channels==4:
1119
- # # masking for non-inpainting models
1120
- # init_latents_proper = self.scheduler.add_noise(init_masked_image_latents, noise, t)
1121
- # latents = (init_latents_proper * mask_image) + (latents * (1 - mask_image))
1122
 
1123
- if self.unet.config.in_channels==4:
 
 
 
 
 
 
 
1124
  # fill the unmasked part with original image
1125
  latents = (init_masked_image_latents * mask_image) + (latents * (1 - mask_image))
1126
 
 
835
  callback_steps: int = 1,
836
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
837
  controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
838
+ alignment_ratio = None,
839
  ):
840
  r"""
841
  Function invoked when calling the pipeline for generation.
 
1116
  progress_bar.update()
1117
  if callback is not None and i % callback_steps == 0:
1118
  callback(i, t, latents)
 
 
 
 
1119
 
1120
+ if self.unet.config.in_channels==4 and alignment_ratio is not None:
1121
+ if i < len(timesteps) * alignment_ratio:
1122
+ # print(i, len(timesteps))
1123
+ # masking for non-inpainting models
1124
+ init_latents_proper = self.scheduler.add_noise(init_masked_image_latents, noise, t)
1125
+ latents = (init_latents_proper * mask_image) + (latents * (1 - mask_image))
1126
+
1127
+ if self.unet.config.in_channels==4 and (alignment_ratio==1.0 or alignment_ratio is None):
1128
  # fill the unmasked part with original image
1129
  latents = (init_masked_image_latents * mask_image) + (latents * (1 - mask_image))
1130
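The functional change to the pipeline is the new `alignment_ratio` argument: for a 4-channel (non-inpainting) UNet, the first `alignment_ratio` fraction of denoising steps re-imposes a freshly noised copy of the original-image latents on the keep region, while the remaining steps run free so the model can harmonize the whole image; `alignment_ratio=1.0` (or `None`) reproduces the old behavior of pasting the clean latents back at the end. A toy, runnable illustration of that blending, with the UNet and scheduler replaced by dummies:

```python
# Toy illustration of the alignment_ratio blending; tensors and the denoising step are dummies.
import torch

timesteps = list(range(30, 0, -1))
alignment_ratio = 0.95                      # new pipeline argument from this diff
mask_image = torch.zeros(1, 1, 8, 8)
mask_image[..., :4, :4] = 1.0               # 1 = region kept aligned with the input image
init_masked_image_latents = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(init_masked_image_latents)
latents = torch.randn_like(init_masked_image_latents)

def add_noise(x, n, t):                     # stand-in for scheduler.add_noise
    return x + 0.01 * t * n

for i, t in enumerate(timesteps):
    latents = latents - 0.01 * latents      # stand-in for the UNet/scheduler update

    if alignment_ratio is not None and i < len(timesteps) * alignment_ratio:
        # Early steps: re-noise the original latents to timestep t and paste them
        # into the keep region, enforcing alignment with the input image.
        init_latents_proper = add_noise(init_masked_image_latents, noise, t)
        latents = init_latents_proper * mask_image + latents * (1 - mask_image)
    # The remaining (1 - alignment_ratio) fraction of steps runs free for global consistency.

if alignment_ratio == 1.0 or alignment_ratio is None:
    # Full alignment: paste the clean original latents back at the end (previous behavior).
    latents = init_masked_image_latents * mask_image + latents * (1 - mask_image)
```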