JohannesBerends commited on
Commit
7f173a3
1 Parent(s): 737bdf0

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,38 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
=2.0.0 ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Requirement already satisfied: grpcio in c:\users\administrator\appdata\local\programs\python\python310\lib\site-packages (1.63.0)
2
+ Collecting grpcio
3
+ Using cached grpcio-1.64.1-cp310-cp310-win_amd64.whl.metadata (3.4 kB)
4
+ Using cached grpcio-1.64.1-cp310-cp310-win_amd64.whl (4.1 MB)
5
+ Installing collected packages: grpcio
6
+ Attempting uninstall: grpcio
7
+ Found existing installation: grpcio 1.63.0
8
+ Uninstalling grpcio-1.63.0:
9
+ Successfully uninstalled grpcio-1.63.0
10
+ Successfully installed grpcio-1.64.1
README.md CHANGED
@@ -1,12 +1,17 @@
1
- ---
2
- title: HairFastGAN
3
- emoji: 🐨
4
- colorFrom: blue
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 4.36.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
1
+ ---
2
+ title: HairFastGAN
3
+ emoji: 💈
4
+ colorFrom: pink
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.31.5
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ custom_headers:
12
+ cross-origin-embedder-policy: require-corp
13
+ cross-origin-opener-policy: same-origin
14
+ cross-origin-resource-policy: cross-origin
15
+ ---
16
+
17
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/inference_pb2.cpython-310.pyc ADDED
Binary file (1.21 kB). View file
 
__pycache__/inference_pb2.cpython-39.pyc ADDED
Binary file (1.2 kB). View file
 
__pycache__/inference_pb2_grpc.cpython-310.pyc ADDED
Binary file (3.21 kB). View file
 
__pycache__/inference_pb2_grpc.cpython-39.pyc ADDED
Binary file (3.22 kB). View file
 
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ from io import BytesIO
4
+
5
+ import gradio as gr
6
+ import grpc
7
+ from PIL import Image
8
+ from cachetools import LRUCache
9
+
10
+ from inference_pb2 import HairSwapRequest, HairSwapResponse
11
+ from inference_pb2_grpc import HairSwapServiceStub
12
+ from utils.shape_predictor import align_face
13
+
14
+
15
+ def get_bytes(img):
16
+ if img is None:
17
+ return img
18
+
19
+ buffered = BytesIO()
20
+ img.save(buffered, format="JPEG")
21
+ return buffered.getvalue()
22
+
23
+
24
+ def bytes_to_image(image: bytes) -> Image.Image:
25
+ image = Image.open(BytesIO(image))
26
+ return image
27
+
28
+
29
+ def center_crop(img):
30
+ width, height = img.size
31
+ side = min(width, height)
32
+
33
+ left = (width - side) / 2
34
+ top = (height - side) / 2
35
+ right = (width + side) / 2
36
+ bottom = (height + side) / 2
37
+
38
+ img = img.crop((left, top, right, bottom))
39
+ return img
40
+
41
+
42
+ def resize(name):
43
+ def resize_inner(img, align):
44
+ global align_cache
45
+
46
+ if name in align:
47
+ img_hash = hashlib.md5(get_bytes(img)).hexdigest()
48
+
49
+ if img_hash not in align_cache:
50
+ img = align_face(img, return_tensors=False)[0]
51
+ align_cache[img_hash] = img
52
+ else:
53
+ img = align_cache[img_hash]
54
+
55
+ elif img.size != (1024, 1024):
56
+ img = center_crop(img)
57
+ img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
58
+
59
+ return img
60
+
61
+ return resize_inner
62
+
63
+
64
+ def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
65
+ if not face and not shape and not color:
66
+ return gr.update(visible=False), gr.update(value="Need to upload a face and at least a shape or color ❗", visible=True)
67
+ elif not face:
68
+ return gr.update(visible=False), gr.update(value="Need to upload a face ❗", visible=True)
69
+ elif not shape and not color:
70
+ return gr.update(visible=False), gr.update(value="Need to upload at least a shape or color ❗", visible=True)
71
+
72
+ face_bytes, shape_bytes, color_bytes = map(lambda item: get_bytes(item), (face, shape, color))
73
+
74
+ if shape_bytes is None:
75
+ shape_bytes = b'face'
76
+ if color_bytes is None:
77
+ color_bytes = b'shape'
78
+ if os.environ.get('https_proxy'):
79
+ del os.environ['https_proxy']
80
+ if os.environ.get('http_proxy'):
81
+ del os.environ['http_proxy']
82
+ os.environ['SERVER'] = '172.16.4.26:7860'
83
+
84
+ # with grpc.insecure_channel(os.environ['SERVER'], options=(('grpc.enable_http_proxy', 0),)) as channel:
85
+ # stub = HairSwapServiceStub(channel)
86
+
87
+ # output: HairSwapResponse = stub.swap(
88
+ # HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
89
+ # poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)
90
+ # )
91
+
92
+ # output = bytes_to_image(output.image)
93
+ # return gr.update(value=output, visible=True), gr.update(visible=False)
94
+
95
+
96
+ def get_demo():
97
+ with gr.Blocks() as demo:
98
+ gr.Markdown("## HairFastGan")
99
+ gr.Markdown(
100
+ '<div style="display: flex; align-items: center; gap: 10px;">'
101
+ '<span>Official HairFastGAN Gradio demo:</span>'
102
+ '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
103
+ '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
104
+ '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
105
+ '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
106
+ '</div>'
107
+ )
108
+ with gr.Row():
109
+ with gr.Column():
110
+ source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
111
+ with gr.Row():
112
+ shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
113
+ color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
114
+ with gr.Accordion("Advanced Options", open=False):
115
+ blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
116
+ label="Color Encoder version", info="Selects a model for hair color transfer.")
117
+ poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters",
118
+ info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.")
119
+ poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion",
120
+ info="Smooths out the blending area.")
121
+ align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
122
+ label="Image cropping [recommended]",
123
+ info="Selects which images to crop by face")
124
+ btn = gr.Button("Get the haircut")
125
+ with gr.Column():
126
+ output = gr.Image(label="Your result")
127
+ error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")
128
+
129
+ gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], ["input/6.png", "input/7.png", None],
130
+ ["input/10.jpg", None, "input/11.jpg"]],
131
+ inputs=[source, shape, color], outputs=output)
132
+
133
+ source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
134
+ shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
135
+ color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)
136
+
137
+ btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
138
+ outputs=[output, error_message])
139
+
140
+ gr.Markdown('''To cite the paper by the authors
141
+ ```
142
+ @article{nikolaev2024hairfastgan,
143
+ title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
144
+ author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
145
+ journal={arXiv preprint arXiv:2404.01094},
146
+ year={2024}
147
+ }
148
+ ```
149
+ ''')
150
+ return demo
151
+
152
+
153
+ if __name__ == '__main__':
154
+ align_cache = LRUCache(maxsize=10)
155
+ demo = get_demo()
156
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
inference_pb2.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: inference.proto
4
+ # Protobuf Python Version: 5.26.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\"\x92\x01\n\x0fHairSwapRequest\x12\x0c\n\x04\x66\x61\x63\x65\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x01(\x0c\x12\r\n\x05\x63olor\x18\x03 \x01(\x0c\x12\x10\n\x08\x62lending\x18\x04 \x01(\t\x12\x15\n\rpoisson_iters\x18\x05 \x01(\x05\x12\x17\n\x0fpoisson_erosion\x18\x06 \x01(\x05\x12\x11\n\tuse_cache\x18\x07 \x01(\x08\"!\n\x10HairSwapResponse\x12\r\n\x05image\x18\x01 \x01(\x0c\x32R\n\x0fHairSwapService\x12?\n\x04swap\x12\x1a.inference.HairSwapRequest\x1a\x1b.inference.HairSwapResponseb\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'inference_pb2', _globals)
22
+ if not _descriptor._USE_C_DESCRIPTORS:
23
+ DESCRIPTOR._loaded_options = None
24
+ _globals['_HAIRSWAPREQUEST']._serialized_start=31
25
+ _globals['_HAIRSWAPREQUEST']._serialized_end=177
26
+ _globals['_HAIRSWAPRESPONSE']._serialized_start=179
27
+ _globals['_HAIRSWAPRESPONSE']._serialized_end=212
28
+ _globals['_HAIRSWAPSERVICE']._serialized_start=214
29
+ _globals['_HAIRSWAPSERVICE']._serialized_end=296
30
+ # @@protoc_insertion_point(module_scope)
inference_pb2.pyi ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from google.protobuf import descriptor as _descriptor
2
+ from google.protobuf import message as _message
3
+ from typing import ClassVar as _ClassVar, Optional as _Optional
4
+
5
+ DESCRIPTOR: _descriptor.FileDescriptor
6
+
7
+ class HairSwapRequest(_message.Message):
8
+ __slots__ = ("face", "shape", "color", "blending", "poisson_iters", "poisson_erosion", "use_cache")
9
+ FACE_FIELD_NUMBER: _ClassVar[int]
10
+ SHAPE_FIELD_NUMBER: _ClassVar[int]
11
+ COLOR_FIELD_NUMBER: _ClassVar[int]
12
+ BLENDING_FIELD_NUMBER: _ClassVar[int]
13
+ POISSON_ITERS_FIELD_NUMBER: _ClassVar[int]
14
+ POISSON_EROSION_FIELD_NUMBER: _ClassVar[int]
15
+ USE_CACHE_FIELD_NUMBER: _ClassVar[int]
16
+ face: bytes
17
+ shape: bytes
18
+ color: bytes
19
+ blending: str
20
+ poisson_iters: int
21
+ poisson_erosion: int
22
+ use_cache: bool
23
+ def __init__(self, face: _Optional[bytes] = ..., shape: _Optional[bytes] = ..., color: _Optional[bytes] = ..., blending: _Optional[str] = ..., poisson_iters: _Optional[int] = ..., poisson_erosion: _Optional[int] = ..., use_cache: bool = ...) -> None: ...
24
+
25
+ class HairSwapResponse(_message.Message):
26
+ __slots__ = ("image",)
27
+ IMAGE_FIELD_NUMBER: _ClassVar[int]
28
+ image: bytes
29
+ def __init__(self, image: _Optional[bytes] = ...) -> None: ...
inference_pb2_grpc.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+ import warnings
5
+
6
+ import inference_pb2 as inference__pb2
7
+
8
+ GRPC_GENERATED_VERSION = '1.63.0'
9
+ GRPC_VERSION = grpc.__version__
10
+ EXPECTED_ERROR_RELEASE = '1.65.0'
11
+ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
12
+ _version_not_supported = False
13
+
14
+ try:
15
+ from grpc._utilities import first_version_is_lower
16
+ _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
17
+ except ImportError:
18
+ _version_not_supported = True
19
+
20
+ if _version_not_supported:
21
+ warnings.warn(
22
+ f'The grpc package installed is at version {GRPC_VERSION},'
23
+ + f' but the generated code in inference_pb2_grpc.py depends on'
24
+ + f' grpcio>={GRPC_GENERATED_VERSION}.'
25
+ + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
26
+ + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
27
+ + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
28
+ + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
29
+ RuntimeWarning
30
+ )
31
+
32
+
33
+ class HairSwapServiceStub(object):
34
+ """Missing associated documentation comment in .proto file."""
35
+
36
+ def __init__(self, channel):
37
+ """Constructor.
38
+
39
+ Args:
40
+ channel: A grpc.Channel.
41
+ """
42
+ self.swap = channel.unary_unary(
43
+ '/inference.HairSwapService/swap',
44
+ request_serializer=inference__pb2.HairSwapRequest.SerializeToString,
45
+ response_deserializer=inference__pb2.HairSwapResponse.FromString,
46
+ _registered_method=True)
47
+
48
+
49
+ class HairSwapServiceServicer(object):
50
+ """Missing associated documentation comment in .proto file."""
51
+
52
+ def swap(self, request, context):
53
+ """Missing associated documentation comment in .proto file."""
54
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
55
+ context.set_details('Method not implemented!')
56
+ raise NotImplementedError('Method not implemented!')
57
+
58
+
59
+ def add_HairSwapServiceServicer_to_server(servicer, server):
60
+ rpc_method_handlers = {
61
+ 'swap': grpc.unary_unary_rpc_method_handler(
62
+ servicer.swap,
63
+ request_deserializer=inference__pb2.HairSwapRequest.FromString,
64
+ response_serializer=inference__pb2.HairSwapResponse.SerializeToString,
65
+ ),
66
+ }
67
+ generic_handler = grpc.method_handlers_generic_handler(
68
+ 'inference.HairSwapService', rpc_method_handlers)
69
+ server.add_generic_rpc_handlers((generic_handler,))
70
+
71
+
72
+ # This class is part of an EXPERIMENTAL API.
73
+ class HairSwapService(object):
74
+ """Missing associated documentation comment in .proto file."""
75
+
76
+ @staticmethod
77
+ def swap(request,
78
+ target,
79
+ options=(),
80
+ channel_credentials=None,
81
+ call_credentials=None,
82
+ insecure=False,
83
+ compression=None,
84
+ wait_for_ready=None,
85
+ timeout=None,
86
+ metadata=None):
87
+ return grpc.experimental.unary_unary(
88
+ request,
89
+ target,
90
+ '/inference.HairSwapService/swap',
91
+ inference__pb2.HairSwapRequest.SerializeToString,
92
+ inference__pb2.HairSwapResponse.FromString,
93
+ options,
94
+ channel_credentials,
95
+ insecure,
96
+ call_credentials,
97
+ compression,
98
+ wait_for_ready,
99
+ timeout,
100
+ metadata,
101
+ _registered_method=True)
input/0.png ADDED

Git LFS Details

  • SHA256: 2250590e65e153c785683218c0e2da0c21ae104ac2190857988e05474ca89986
  • Pointer size: 132 Bytes
  • Size of remote file: 1.66 MB
input/1.png ADDED

Git LFS Details

  • SHA256: 5f67d4e98519ee4c1b0dad362bacd95dc7f0c090b1c45ebfcee74a85c660e372
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
input/10.jpg ADDED
input/11.jpg ADDED
input/2.png ADDED

Git LFS Details

  • SHA256: 97e0975d499216762645987955e80ee8b764c062d8b47da8f432bf11004a6443
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
input/3.jpg ADDED
input/4.jpg ADDED
input/5.jpg ADDED
input/6.png ADDED

Git LFS Details

  • SHA256: 91bc9e71396e0e364f66b44d5c1d58d1e5036e53f019a5badffed99d04d7e413
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
input/7.png ADDED

Git LFS Details

  • SHA256: 5b126e1e7858c7d73dd1a0d2b24ca1f178d93040583b4474b1458cf70482ecc5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
input/8.png ADDED

Git LFS Details

  • SHA256: 4b36754e56b501fa74e92a89fd218f4c8571975a722b0138e2897fb1a46f9790
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
input/9.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ pillow==10.0.0
2
+ face_alignment==1.3.4
3
+ addict==2.4.0
4
+ git+https://github.com/openai/CLIP.git
5
+ gdown==3.12.2
6
+ grpcio==1.63.0
7
+ grpcio_tools==1.63.0
8
+ gradio==4.31.5
9
+ cachetools==5.3.3
10
+ dlib==19.24.1
shape_predictor_68_face_landmarks.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
3
+ size 99693937
test.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Provide a default value if SERVER is not set
4
+ server_address = os.environ.get('SERVER', '127.0.0.1')
5
+ print(f"Server address: {os.environ['SERVER']}")
utils/__pycache__/drive.cpython-310.pyc ADDED
Binary file (3.58 kB). View file
 
utils/__pycache__/drive.cpython-39.pyc ADDED
Binary file (3.6 kB). View file
 
utils/__pycache__/shape_predictor.cpython-310.pyc ADDED
Binary file (5.79 kB). View file
 
utils/__pycache__/shape_predictor.cpython-39.pyc ADDED
Binary file (5.8 kB). View file
 
utils/drive.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # URL helpers, see https://github.com/NVlabs/stylegan
2
+ # ------------------------------------------------------------------------------------------
3
+
4
+ import requests
5
+ import html
6
+ import hashlib
7
+ import gdown
8
+ import glob
9
+ import os
10
+ import io
11
+ from typing import Any
12
+ import re
13
+ import uuid
14
+
15
+ weight_dic = {'afhqwild.pt': 'https://drive.google.com/file/d/14OnzO4QWaAytKXVqcfWo_o2MzoR4ygnr/view?usp=sharing',
16
+ 'afhqdog.pt': 'https://drive.google.com/file/d/16v6jPtKVlvq8rg2Sdi3-R9qZEVDgvvEA/view?usp=sharing',
17
+ 'afhqcat.pt': 'https://drive.google.com/file/d/1HXLER5R3EMI8DSYDBZafoqpX4EtyOf2R/view?usp=sharing',
18
+ 'ffhq.pt': 'https://drive.google.com/file/d/1AT6bNR2ppK8f2ETL_evT27f3R_oyWNHS/view?usp=sharing',
19
+ 'metfaces.pt': 'https://drive.google.com/file/d/16wM2PwVWzaMsRgPExvRGsq6BWw_muKbf/view?usp=sharing',
20
+ 'seg.pth': 'https://drive.google.com/file/d/1lIKvQaFKHT5zC7uS4p17O9ZpfwmwlS62/view?usp=sharing'}
21
+
22
+
23
+ def download_weight(weight_path):
24
+ gdown.download(weight_dic[os.path.basename(weight_path)],
25
+ output=weight_path, fuzzy=True)
26
+
27
+
28
+ def is_url(obj: Any) -> bool:
29
+ """Determine whether the given object is a valid URL string."""
30
+ if not isinstance(obj, str) or not "://" in obj:
31
+ return False
32
+ try:
33
+ res = requests.compat.urlparse(obj)
34
+ if not res.scheme or not res.netloc or not "." in res.netloc:
35
+ return False
36
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
37
+ if not res.scheme or not res.netloc or not "." in res.netloc:
38
+ return False
39
+ except:
40
+ return False
41
+ return True
42
+
43
+
44
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True,
45
+ return_path: bool = False) -> Any:
46
+ """Download the given URL and return a binary-mode file object to access the data."""
47
+ assert is_url(url)
48
+ assert num_attempts >= 1
49
+
50
+ # Lookup from cache.
51
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
52
+ if cache_dir is not None:
53
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
54
+ if len(cache_files) == 1:
55
+ if (return_path):
56
+ return cache_files[0]
57
+ else:
58
+ return open(cache_files[0], "rb")
59
+
60
+ # Download.
61
+ url_name = None
62
+ url_data = None
63
+ with requests.Session() as session:
64
+ if verbose:
65
+ print("Downloading %s ..." % url, end="", flush=True)
66
+ for attempts_left in reversed(range(num_attempts)):
67
+ try:
68
+ with session.get(url) as res:
69
+ res.raise_for_status()
70
+ if len(res.content) == 0:
71
+ raise IOError("No data received")
72
+
73
+ if len(res.content) < 8192:
74
+ content_str = res.content.decode("utf-8")
75
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
76
+ links = [html.unescape(link) for link in content_str.split('"') if
77
+ "export=download" in link]
78
+ if len(links) == 1:
79
+ url = requests.compat.urljoin(url, links[0])
80
+ raise IOError("Google Drive virus checker nag")
81
+ if "Google Drive - Quota exceeded" in content_str:
82
+ raise IOError("Google Drive quota exceeded")
83
+
84
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
85
+ url_name = match[1] if match else url
86
+ url_data = res.content
87
+ if verbose:
88
+ print(" done")
89
+ break
90
+ except:
91
+ if not attempts_left:
92
+ if verbose:
93
+ print(" failed")
94
+ raise
95
+ if verbose:
96
+ print(".", end="", flush=True)
97
+
98
+ # Save to cache.
99
+ if cache_dir is not None:
100
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
101
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
102
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
103
+ os.makedirs(cache_dir, exist_ok=True)
104
+ with open(temp_file, "wb") as f:
105
+ f.write(url_data)
106
+ os.replace(temp_file, cache_file) # atomic
107
+ if (return_path): return cache_file
108
+
109
+ # Return data as file object.
110
+ return io.BytesIO(url_data)
utils/shape_predictor.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import PIL
5
+ import dlib
6
+ import numpy as np
7
+ import scipy
8
+ import scipy.ndimage
9
+ import torch
10
+ from PIL import Image
11
+ from torchvision import transforms as T
12
+
13
+ from utils.drive import open_url
14
+
15
+ """
16
+ brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
17
+ author: lzhbrian (https://lzhbrian.me)
18
+ date: 2020.1.5
19
+ note: code is heavily borrowed from
20
+ https://github.com/NVlabs/ffhq-dataset
21
+ http://dlib.net/face_landmark_detection.py.html
22
+
23
+ requirements:
24
+ apt install cmake
25
+ conda install Pillow numpy scipy
26
+ pip install dlib
27
+ # download face landmark model from:
28
+ # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
29
+ """
30
+
31
+
32
+ def get_landmark(filepath, predictor):
33
+ """get landmark with dlib
34
+ :return: np.array shape=(68, 2)
35
+ """
36
+ detector = dlib.get_frontal_face_detector()
37
+
38
+ img = dlib.load_rgb_image(filepath)
39
+ dets = detector(img, 1)
40
+ filepath = Path(filepath)
41
+ print(f"{filepath.name}: Number of faces detected: {len(dets)}")
42
+ shapes = [predictor(img, d) for k, d in enumerate(dets)]
43
+
44
+ lms = [np.array([[tt.x, tt.y] for tt in shape.parts()]) for shape in shapes]
45
+
46
+ return lms
47
+
48
+
49
+ def get_landmark_from_tensors(tensors: list[torch.Tensor | Image.Image | np.ndarray], predictor):
50
+ detector = dlib.get_frontal_face_detector()
51
+ transform = T.ToPILImage()
52
+ images = []
53
+ lms = []
54
+
55
+ for k, tensor in enumerate(tensors):
56
+ if isinstance(tensor, torch.Tensor):
57
+ img_pil = transform(tensor)
58
+ else:
59
+ img_pil = tensor
60
+ img = np.array(img_pil)
61
+ images.append(img_pil)
62
+
63
+ dets = detector(img, 1)
64
+ if len(dets) == 0:
65
+ raise ValueError(f"No faces detected in the image {k}.")
66
+ elif len(dets) == 1:
67
+ print(f"Number of faces detected: {len(dets)}")
68
+ else:
69
+ print(f"Number of faces detected: {len(dets)}, get largest face")
70
+
71
+ # Find the largest face
72
+ dets = sorted(dets, key=lambda det: det.width() * det.height(), reverse=True)
73
+ shape = predictor(img, dets[0])
74
+ lm = np.array([[tt.x, tt.y] for tt in shape.parts()])
75
+ lms.append(lm)
76
+
77
+ return images, lms
78
+
79
+
80
+ def align_face(data, predictor=None, is_filepath=False, return_tensors=True):
81
+ """
82
+ :param data: filepath or list torch Tensors
83
+ :return: list of PIL Images
84
+ """
85
+ if predictor is None:
86
+ predictor_path = 'shape_predictor_68_face_landmarks.dat'
87
+
88
+ if not os.path.isfile(predictor_path):
89
+ print("Downloading Shape Predictor")
90
+ data_io = open_url("https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx")
91
+ with open(predictor_path, 'wb') as f:
92
+ f.write(data_io.getbuffer())
93
+
94
+ predictor = dlib.shape_predictor(predictor_path)
95
+
96
+ if is_filepath:
97
+ lms = get_landmark(data, predictor)
98
+ else:
99
+ if not isinstance(data, list):
100
+ data = [data]
101
+ images, lms = get_landmark_from_tensors(data, predictor)
102
+
103
+ imgs = []
104
+ for num_img, lm in enumerate(lms):
105
+ lm_chin = lm[0: 17] # left-right
106
+ lm_eyebrow_left = lm[17: 22] # left-right
107
+ lm_eyebrow_right = lm[22: 27] # left-right
108
+ lm_nose = lm[27: 31] # top-down
109
+ lm_nostrils = lm[31: 36] # top-down
110
+ lm_eye_left = lm[36: 42] # left-clockwise
111
+ lm_eye_right = lm[42: 48] # left-clockwise
112
+ lm_mouth_outer = lm[48: 60] # left-clockwise
113
+ lm_mouth_inner = lm[60: 68] # left-clockwise
114
+
115
+ # Calculate auxiliary vectors.
116
+ eye_left = np.mean(lm_eye_left, axis=0)
117
+ eye_right = np.mean(lm_eye_right, axis=0)
118
+ eye_avg = (eye_left + eye_right) * 0.5
119
+ eye_to_eye = eye_right - eye_left
120
+ mouth_left = lm_mouth_outer[0]
121
+ mouth_right = lm_mouth_outer[6]
122
+ mouth_avg = (mouth_left + mouth_right) * 0.5
123
+ eye_to_mouth = mouth_avg - eye_avg
124
+
125
+ # Choose oriented crop rectangle.
126
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
127
+ x /= np.hypot(*x)
128
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
129
+ y = np.flipud(x) * [-1, 1]
130
+ c = eye_avg + eye_to_mouth * 0.1
131
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
132
+ qsize = np.hypot(*x) * 2
133
+
134
+ # read image
135
+ if is_filepath:
136
+ img = PIL.Image.open(data)
137
+ else:
138
+ img = images[num_img]
139
+
140
+ output_size = 1024
141
+ # output_size = 256
142
+ transform_size = 4096
143
+ enable_padding = True
144
+
145
+ # Shrink.
146
+ shrink = int(np.floor(qsize / output_size * 0.5))
147
+ if shrink > 1:
148
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
149
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
150
+ quad /= shrink
151
+ qsize /= shrink
152
+
153
+ # Crop.
154
+ border = max(int(np.rint(qsize * 0.1)), 3)
155
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
156
+ int(np.ceil(max(quad[:, 1]))))
157
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
158
+ min(crop[3] + border, img.size[1]))
159
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
160
+ img = img.crop(crop)
161
+ quad -= crop[0:2]
162
+
163
+ # Pad.
164
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
165
+ int(np.ceil(max(quad[:, 1]))))
166
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
167
+ max(pad[3] - img.size[1] + border, 0))
168
+ if enable_padding and max(pad) > border - 4:
169
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
170
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
171
+ h, w, _ = img.shape
172
+ y, x, _ = np.ogrid[:h, :w, :1]
173
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
174
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
175
+ blur = qsize * 0.02
176
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
177
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
178
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
179
+ quad += pad[:2]
180
+
181
+ # Transform.
182
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(),
183
+ PIL.Image.BILINEAR)
184
+ if output_size < transform_size:
185
+ img = img.resize((output_size, output_size), PIL.Image.LANCZOS)
186
+
187
+ # Save aligned image.
188
+ imgs.append(img)
189
+
190
+ if return_tensors:
191
+ transform = T.ToTensor()
192
+ tensors = [transform(img).clamp(0, 1) for img in imgs]
193
+ return tensors
194
+ return imgs