Akjava's picture
iniit
58665c8
raw
history blame
8.95 kB
import spaces
import gradio as gr
import subprocess
from PIL import Image
import json
import os
import time
import mp_box
import draw_landmarks68
import landmarks68_utils
import io
import numpy as np
'''
Face landmark detection based on face detection.
https://ai.google.dev/edge/mediapipe/solutions/vision/face_landmarker
From the model card:
https://storage.googleapis.com/mediapipe-assets/MediaPipe%20BlazeFace%20Model%20Card%20(Short%20Range).pdf
Licensed under the Apache License, Version 2.0.
Trained on Google's dataset (see the model card for more details).
'''
# Directory (relative to the working directory) where generated masks and
# JSON downloads are written.
dir_name = "files"
# Age threshold in seconds (one hour); presumably meant for
# clear_old_files(dir_name, passed_time) — not called in the visible code.
passed_time = 3600
def clear_old_files(dir, passed_time):
    """Delete files in ``dir`` whose ctime is more than ``passed_time`` seconds old.

    This is a best-effort cleanup. If the directory is missing, or an entry
    cannot be removed (e.g. a subdirectory, or a file presumably still held
    open by the gallery), the problem is reported and skipped. The rest of
    the sweep still runs.

    Args:
        dir: Directory to sweep. The sweep is not recursive.
        passed_time: Age threshold in seconds. Entries with
            ``time.time() - st_ctime > passed_time`` are removed.
    """
    try:
        file_names = os.listdir(dir)
    except OSError as err:
        print(f"maybe still gallery using error: {err}")
        return

    current_time = time.time()
    for file_name in file_names:
        file_path = os.path.join(dir, file_name)
        try:
            # st_ctime is the metadata-change time on Unix and the creation
            # time on Windows.
            age = current_time - os.stat(file_path).st_ctime
            if age > passed_time:
                os.remove(file_path)
        except OSError as err:
            # Skip this entry only, so one failure doesn't abort the sweep.
            print(f"maybe still gallery using error: {err}")
def get_image_id(image, length=32):
    """Return a content-derived identifier for ``image``.

    The image is serialized to PNG in memory and hashed with SHA-256. The
    first ``length`` hex characters of the digest are returned, so identical
    pixel content yields the same id.
    """
    png_buffer = io.BytesIO()
    image.save(png_buffer, format='PNG')
    digest = hashlib.sha256(png_buffer.getvalue()).hexdigest()
    return digest[:length]
def save_image(image, extension="jpg"):
    """Save ``image`` under ``dir_name`` with a content-hash file name.

    The directory is created if needed. Because the name comes from
    get_image_id, identical images map to the same path and overwrite it.

    Returns:
        The relative path ``"<dir_name>/<id>.<extension>"``.
    """
    image_id = get_image_id(image)
    os.makedirs(dir_name, exist_ok=True)
    file_path = f"{dir_name}/{image_id}.{extension}"
    image.save(file_path)
    return file_path
def picker_color_to_rgba(picker_color):
    """Parse a color-picker value into ``[r, g, b, a]`` integers.

    Accepted forms:
      * ``"rgba(r, g, b, a)"``: each component is parsed as a float and
        truncated to int, as before. A fractional alpha such as ``0.5``
        therefore becomes ``0``; this is kept for backward compatibility.
      * ``"rgb(r, g, b)"``: alpha defaults to ``1`` (opaque, on the same
        0-1 scale as the rgba form).
      * ``"#RRGGBB"`` / ``"#RGB"`` hex strings (Gradio's ColorPicker format):
        alpha defaults to ``1``.

    Surrounding whitespace is ignored.

    Raises:
        ValueError: if the value does not match one of the forms above.
    """
    text = picker_color.strip()

    if text.startswith("#"):
        hex_digits = text[1:]
        if len(hex_digits) == 3:
            # Expand shorthand "#abc" to "#aabbcc".
            hex_digits = "".join(ch * 2 for ch in hex_digits)
        if len(hex_digits) != 6:
            raise ValueError(f"unsupported color value: {picker_color!r}")
        return [int(hex_digits[i:i + 2], 16) for i in (0, 2, 4)] + [1]

    # strip() removes any of the characters r, g, b, a, "(" and ")" from both
    # ends, which handles both the "rgba(" and "rgb(" prefixes.
    components = text.strip("rgba()").split(",")
    if len(components) == 3:
        components.append("1")
    if len(components) != 4:
        raise ValueError(f"unsupported color value: {picker_color!r}")
    return [int(float(component)) for component in components]
#@spaces.GPU(duration=120)
def process_images(image, progress=gr.Progress(track_tqdm=True)):
    """Detect 68 face landmarks in ``image`` and build per-part masks.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        progress: Gradio progress tracker.

    Returns:
        ``([image, annotations], galleries)``:
          * ``annotations`` is a list of ``(numpy mask from mode "1", label)``
            pairs for ``gr.AnnotatedImage``.
          * ``galleries`` is a list of ``(saved file path, label)`` pairs for
            ``gr.Gallery``.
        The labels are, in order: eyes, upper-lip, lower-lip, inner-mouth,
        contour, mixed.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        raise gr.Error("Need Image")

    progress(0, desc="Start Mediapipe")
    # Only the landmarker result is needed from the MediaPipe call.
    _, _, face_landmarker_result = mp_box.mediapipe_to_box(image)
    # The annotated image and bbox are not displayed; only the points are used.
    _, _, landmark_points = draw_landmarks68.draw_landmarks_on_image(image, face_landmarker_result)
    landmark_list = draw_landmarks68.convert_to_landmark_group_json(landmark_points)

    annotations = []
    galleries = []

    def append(mask, label):
        """Save ``mask`` for the gallery and record it as a 1-bit annotation."""
        # NOTE(review): save_image defaults to JPEG, which is lossy for
        # black/white masks; consider save_image(mask, "png") if exact masks matter.
        file_path = save_image(mask)
        galleries.append((file_path, label))
        annotations.append((np.array(mask.convert("1")), label))

    def fill_points(points, base_image=None):
        """Draw ``points`` via landmarks68_utils.fill_points onto ``base_image``.

        If ``base_image`` is None, a new image-sized canvas created with color
        (0, 0, 0) is used instead. The (possibly new) canvas is returned.
        """
        if base_image is None:
            base_image = landmarks68_utils.create_color_image(image.width, image.height, (0, 0, 0))
        landmarks68_utils.fill_points(base_image, points)
        return base_image

    # TODO support type
    # Both eyes are drawn on a single mask.
    left_eye_points = landmarks68_utils.get_landmark_points(landmark_list, landmarks68_utils.PARTS_LEFT_EYE)
    right_eye_points = landmarks68_utils.get_landmark_points(landmark_list, landmarks68_utils.PARTS_RIGHT_EYE)
    eyes_mask = fill_points(left_eye_points)
    eyes_mask = fill_points(right_eye_points, eyes_mask)
    append(eyes_mask, "eyes")

    upper_lip_points = landmarks68_utils.get_landmark_points(landmark_list, landmarks68_utils.PARTS_UPPER_LIP)
    upper_lip_mask = fill_points(upper_lip_points)
    append(upper_lip_mask, "upper-lip")

    lower_lip_points = landmarks68_utils.get_landmark_points(landmark_list, landmarks68_utils.PARTS_LOWER_LIP)
    lower_lip_mask = fill_points(lower_lip_points)
    append(lower_lip_mask, "lower-lip")

    inner_mouth_points = landmarks68_utils.get_innner_mouth_points(landmark_list)
    inner_mouth_mask = fill_points(inner_mouth_points)
    append(inner_mouth_mask, "inner-mouth")

    # TODO support type
    # The face outline comes from get_face_points rather than the
    # PARTS_CONTOUR subset. The former PARTS_CONTOUR lookup was overwritten
    # immediately and has been removed.
    contour_points = landmarks68_utils.get_face_points(landmark_list)
    contour_mask = fill_points(contour_points)
    append(contour_mask, "contour")

    # Union of the eye and mouth masks (the contour is excluded). At each
    # step, Image.composite keeps the accumulated mask where its grayscale
    # version is non-black and shows the next mask elsewhere.
    mixed = Image.composite(eyes_mask, upper_lip_mask, eyes_mask.convert("L"))
    mixed = Image.composite(mixed, lower_lip_mask, mixed.convert("L"))
    mixed = Image.composite(mixed, inner_mouth_mask, mixed.convert("L"))
    append(mixed, "mixed")

    return [image, annotations], galleries
def write_file(file_path, text):
    """Write ``text`` to ``file_path`` as UTF-8, replacing any existing content."""
    with open(file_path, mode='w', encoding='utf-8') as handle:
        handle.write(text)
def read_file(file_path):
    """Return the full contents of ``file_path``, decoded as UTF-8 text."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
# Custom stylesheet passed to gr.Blocks(css=css).
# - #col-left / #col-right: center each column and cap its width at 640px.
# - .grid-container: a centered flex row with 10px gaps.
# - .image / .text: 128px cover-fit thumbnails and 16px captions.
# NOTE(review): none of these selectors match an elem_id set in this file;
# they may be used by the demo_*.html fragments. Confirm before removing.
css="""
#col-left {
margin: 0 auto;
max-width: 640px;
}
#col-right {
margin: 0 auto;
max-width: 640px;
}
.grid-container {
display: flex;
align-items: center;
justify-content: center;
gap:10px
}
.image {
width: 128px;
height: 128px;
object-fit: cover;
}
.text {
font-size: 16px;
}
"""
#css=css,
import hashlib
def text_to_sha256(text):
    """Return the lowercase hex SHA-256 digest of ``text`` encoded as UTF-8."""
    return hashlib.sha256(text.encode('utf-8')).hexdigest()
def create_json_download(text):
    """Write ``text`` to a content-addressed JSON file and return its path.

    The file is ``<dir_name>/landmark_<first 32 hex chars of sha256(text)>.json``.
    ``dir_name`` is created if needed: this may run before save_image has
    ever created it, and writing would then fail with FileNotFoundError.
    """
    os.makedirs(dir_name, exist_ok=True)
    file_id = f"{dir_name}/landmark_{text_to_sha256(text)[:32]}.json"
    write_file(file_id, text)
    return file_id
# Gradio UI layout:
# - header and tool links,
# - an input column (image upload, run button, advanced settings),
# - an output column (annotated image and mask gallery),
# - examples and a footer.
# NOTE(review): original indentation was lost; the nesting below is a best
# reconstruction and should be confirmed against the upstream file.
with gr.Blocks(css=css, elem_id="demo-container") as demo:
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))
    with gr.Row():
        # Input column.
        with gr.Column():
            image = gr.Image(height=800,sources=['upload','clipboard'],image_mode='RGB',elem_id="image_upload", type="pil", label="Upload")
            with gr.Row(elem_id="prompt-container", equal_height=False):
                with gr.Row():
                    btn = gr.Button("Create Landmark 68 Mask", elem_id="run_button",variant="primary")
            # NOTE(review): none of the controls in this accordion appear in
            # btn.click(inputs=[image]), so they currently have no effect on
            # process_images.
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row( equal_height=True):
                    draw_number = gr.Checkbox(label="draw Number")
                    font_scale = gr.Slider(
                        label="Font Scale",
                        minimum=0.1,
                        maximum=2,
                        step=0.1,
                        value=0.5)
                    text_color = gr.ColorPicker(value="rgba(200,200,200,1)",label="text color")
                    #square_shape = gr.Checkbox(label="Square shape")
                with gr.Row( equal_height=True):
                    line_color = gr.ColorPicker(value="rgba(0,0,255,1)",label="line color")
                    line_size = gr.Slider(
                        label="Line Size",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=1)
                with gr.Row( equal_height=True):
                    dot_color = gr.ColorPicker(value="rgba(255,0,0,1)",label="dot color")
                    dot_size = gr.Slider(
                        label="Dot Size",
                        minimum=0,
                        maximum=40,
                        step=1,
                        value=3)
                with gr.Row( equal_height=True):
                    box_color = gr.ColorPicker(value="rgba(200,200,200,1)",label="box color")
                    box_size = gr.Slider(
                        label="Box Size",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=1)
                with gr.Row( equal_height=True):
                    json_format = gr.Radio(choices=["raw","face-detection"],value="face-detection",label="json-output format")
        # Output column.
        with gr.Column():
            image_out = gr.AnnotatedImage(label="Output", elem_id="output-img")
            image_gallery = gr.Gallery(label="masks",preview=True)
    # Disabled: json_download / text_out / download_button are not defined in this file.
    #download_button.click(fn=json_download,inputs=text_out,outputs=download_button)
    # process_images returns ([image, annotations], galleries), which map onto
    # outputs=[image_out, image_gallery]. The handler is exposed over the
    # API as "infer".
    btn.click(fn=process_images, inputs=[image],outputs=[image_out,image_gallery] ,api_name='infer')
    gr.Examples(
        examples =["examples/00003245_00.jpg","examples/00004200.jpg","examples/00002200.jpg","examples/00005259.jpg","examples/00018022.jpg","examples/img-above.jpg","examples/img-below.jpg","examples/img-side.jpg"],
        inputs=[image]
    )
    gr.HTML(read_file("demo_footer.html"))
# Launch the app only when this file is run directly.
if __name__ == "__main__":
    demo.launch()