File size: 6,730 Bytes
3d3e65e
23faa2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d3e65e
 
 
 
 
 
 
23faa2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d3e65e
23faa2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d3e65e
23faa2e
3d3e65e
 
23faa2e
 
3d3e65e
23faa2e
 
 
 
 
3d3e65e
 
23faa2e
 
 
3d3e65e
23faa2e
 
 
 
 
 
 
 
 
3d3e65e
 
23faa2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import hashlib
import os
from io import BytesIO

import gradio as gr
import grpc
from PIL import Image
from cachetools import LRUCache

from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face


def get_bytes(img):
    """Encode a PIL image as JPEG bytes.

    Args:
        img: A PIL image, or ``None``.

    Returns:
        The JPEG-encoded bytes of ``img``, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # Pillow's JPEG writer only accepts these modes; others (RGBA, LA, P,
    # I;16, ... e.g. a PNG with transparency) raise "cannot write mode ... as
    # JPEG". Flatten them to RGB instead of crashing.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    """Decode encoded image bytes (such as the server's response) into a PIL image."""
    stream = BytesIO(image)
    return Image.open(stream)


def center_crop(img):
    """Crop the largest centered square out of an image.

    Args:
        img: A PIL image (anything exposing ``size`` and ``crop``).

    Returns:
        The cropped image, exactly ``min(width, height)`` pixels on each side.
    """
    width, height = img.size
    side = min(width, height)

    # Use integer offsets. With float offsets PIL rounds each box edge on its
    # own (round-half-to-even), so e.g. a 4x3 image produced a 4x3 box rather
    # than 3x3. Anchoring right/bottom at left/top + side guarantees a square.
    left = (width - side) // 2
    top = (height - side) // 2

    return img.crop((left, top, left + side, top + side))


def resize(name):
    """Build a Gradio upload handler that normalises the image of one input slot.

    Args:
        name: Label of the slot this handler serves ("Face", "Shape" or
            "Color"); it is matched against the values ticked in the
            "Image cropping" checkbox group.

    Returns:
        A callable ``(img, align) -> img``. ``None`` is returned unchanged.
        If ``name`` is in ``align`` the image is face-aligned with
        ``align_face`` and the result is memoised in the module-level
        ``align_cache``. Otherwise an image that is not already 1024x1024 is
        center-cropped to a square and resized to 1024x1024 (Lanczos).
    """

    def resize_inner(img, align):
        # Nothing to process; get_bytes()/img.size would fail on None.
        if img is None:
            return img

        if name in align:
            # Key the cache on the image content so re-uploading the same
            # picture reuses the previous alignment result.
            img_hash = hashlib.md5(get_bytes(img)).hexdigest()

            if img_hash not in align_cache:
                # NOTE(review): assumes align_face finds at least one face;
                # [0] raises IndexError otherwise — confirm desired behavior.
                aligned = align_face(img, return_tensors=False)[0]
                align_cache[img_hash] = aligned
                return aligned
            return align_cache[img_hash]

        if img.size != (1024, 1024):
            img = center_crop(img)
            img = img.resize((1024, 1024), Image.Resampling.LANCZOS)

        return img

    return resize_inner


def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
    """Send the uploaded images to the HairSwap gRPC service and show the result.

    Args:
        face: Source photo (PIL image) to try the hairstyle on; required.
        shape: Photo with the desired hairstyle, or ``None``.
        color: Photo with the desired hair color, or ``None``.
        blending: Color-encoder version selected in the UI.
        poisson_iters: Poisson blending iterations from the UI slider.
        poisson_erosion: Poisson erosion value from the UI slider.

    Returns:
        A pair of ``gr.update`` objects for ``(output image, error textbox)``.
        On success the result image is shown and the error box hidden. When
        inputs are missing or the gRPC call fails, the image is hidden and
        the error box shows a message.

    The server address is read from the ``SERVER`` environment variable.
    """
    missing_face = face is None
    missing_reference = shape is None and color is None

    if missing_face and missing_reference:
        return gr.update(visible=False), gr.update(value="Need to upload a face and at least a shape or color ❗", visible=True)
    if missing_face:
        return gr.update(visible=False), gr.update(value="Need to upload a face ❗", visible=True)
    if missing_reference:
        return gr.update(visible=False), gr.update(value="Need to upload at least a shape or color ❗", visible=True)

    face_bytes, shape_bytes, color_bytes = map(get_bytes, (face, shape, color))

    # Placeholder payloads for an omitted reference image.
    # NOTE(review): these look like markers telling the server to reuse the
    # face image as the shape and the shape image as the color — confirm
    # against the server implementation.
    if shape_bytes is None:
        shape_bytes = b'face'
    if color_bytes is None:
        color_bytes = b'shape'

    request = HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
                              poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)

    try:
        with grpc.insecure_channel(os.environ['SERVER']) as channel:
            stub = HairSwapServiceStub(channel)
            response: HairSwapResponse = stub.swap(request)
    except grpc.RpcError as err:
        # Report server/transport failures in the error box, like the input
        # checks above, instead of raising an unhandled exception in the UI.
        details = err.details() if isinstance(err, grpc.Call) else None
        return gr.update(visible=False), gr.update(value=f"Hair swap server error: {details or err} ❗", visible=True)

    result = bytes_to_image(response.image)
    return gr.update(value=result, visible=True), gr.update(visible=False)


def get_demo():
    """Build the HairFastGAN Gradio Blocks interface.

    The left column holds the three input images (source face, optional
    shape and optional color references), the advanced options and the run
    button. The right column holds the result image and a hidden error box.
    Each image's upload event is routed through a ``resize`` handler, and the
    button calls ``swap_hair``.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    with gr.Blocks() as demo:
        # Page title and badge links (arXiv paper, GitHub, Hugging Face model, Colab).
        gr.Markdown("## HairFastGan")
        gr.Markdown(
            '<div style="display: flex; align-items: center; gap: 10px;">'
            '<span>Official HairFastGAN Gradio demo:</span>'
            '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
            '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
            '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
            '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
            '</div>'
        )
        with gr.Row():
            # Left column: input images, advanced options and the run button.
            with gr.Column():
                source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
                with gr.Row():
                    shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
                    color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
                with gr.Accordion("Advanced Options", open=False):
                    # blending / poisson_iters / poisson_erosion are passed to
                    # swap_hair and forwarded in the gRPC request.
                    blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
                                        label="Color Encoder version", info="Selects a model for hair color transfer.")
                    poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters",
                                              info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.")
                    poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion",
                                                info="Smooths out the blending area.")
                    # Names ticked here are matched against the name given to
                    # resize(): ticked images are face-aligned on upload, while
                    # unticked ones not already 1024x1024 are center-cropped and
                    # resized to 1024x1024.
                    align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
                                             label="Image cropping [recommended]",
                                             info="Selects which images to crop by face")
                btn = gr.Button("Get the haircut")
            # Right column: result image plus an error box that swap_hair
            # makes visible when inputs are missing.
            with gr.Column():
                output = gr.Image(label="Your result")
                error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")

        # Example rows; None leaves that slot empty (shape-only and color-only examples).
        gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], ["input/6.png", "input/7.png", None],
                              ["input/10.jpg", None, "input/11.jpg"]],
                    inputs=[source, shape, color], outputs=output)

        # Normalise each image right after upload; the handler's return value
        # is written back into the same component.
        source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
        shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
        color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)

        btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
                  outputs=[output, error_message])

        # Citation block (BibTeX) shown at the bottom of the page.
        gr.Markdown('''To cite the paper by the authors

    ```

        @article{nikolaev2024hairfastgan,

          title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},

          author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},

          journal={arXiv preprint arXiv:2404.01094},

          year={2024}

        }

    ```

        ''')
    return demo


# LRU cache of face-aligned images used by the resize() upload handlers,
# keyed by the MD5 of the JPEG-encoded upload. It is defined at module level,
# not under the __main__ guard, so it also exists when this module is
# imported and get_demo() is used without running this file as a script.
align_cache = LRUCache(maxsize=10)

if __name__ == '__main__':
    demo = get_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)