Akjava's picture
change default value
217cf5d
import spaces
import gradio as gr
import subprocess
from PIL import Image
import json
import os
import time
import mp_box
import draw_landmarks68
'''
Face detection based on face landmark detection.
https://ai.google.dev/edge/mediapipe/solutions/vision/face_landmarker
From the model card:
https://storage.googleapis.com/mediapipe-assets/MediaPipe%20BlazeFace%20Model%20Card%20(Short%20Range).pdf
Licensed under the Apache License, Version 2.0.
Trained with Google's dataset (see the model card for more detail).
'''
# Directory that receives the generated landmark JSON files.
dir_name = "files"
# Age threshold in seconds (one hour) handed to clear_old_files().
passed_time = 60 * 60
def clear_old_files(dir,passed_time):
    """Best-effort removal of stale entries from a directory.

    Every entry directly inside ``dir`` whose ``os.stat().st_ctime`` lies more
    than ``passed_time`` seconds before the current time is deleted with
    ``os.remove``.

    Failures never propagate: an entry that cannot be inspected or removed
    (for example because it is still in use, or because it is a
    sub-directory) is reported on stdout and skipped, and the remaining
    entries are still processed.

    Args:
        dir: Path of the directory to clean. The name shadows the builtin
            but is kept so existing callers keep working.
        passed_time: Age threshold in seconds.

    Returns:
        None.
    """
    try:
        names = os.listdir(dir)
    except OSError as err:
        # Missing or unreadable directory: nothing to clean.
        print("maybe still gallery using error", err)
        return

    current_time = time.time()
    for name in names:
        file_path = os.path.join(dir, name)
        # Guard each entry separately so one failure does not stop the
        # cleanup of everything that follows it.
        try:
            age = current_time - os.stat(file_path).st_ctime
            if age > passed_time:
                os.remove(file_path)
        except OSError as err:
            print("maybe still gallery using error", file_path, err)
def picker_color_to_rgba(picker_color):
    """Convert a CSS-style color string into a ``[r, g, b, a]`` list of ints.

    Accepted forms:
        * ``rgba(r, g, b, a)`` - each component is parsed with
          ``int(float(...))``, so fractional values are truncated.
        * ``rgb(r, g, b)`` - alpha defaults to 1.
        * ``#rgb`` / ``#rrggbb`` - alpha defaults to 1.
        * ``#rgba`` / ``#rrggbbaa`` - the alpha byte is scaled to 0..1 and
          truncated to an int, matching what ``rgba()`` input yields for a
          fractional alpha.

    Args:
        picker_color: Color string, e.g. ``"rgba(200,200,200,1)"``.

    Returns:
        A list of four ints ``[r, g, b, a]``.

    Raises:
        ValueError: If the string is not in one of the accepted forms.
    """
    text = picker_color.strip()

    if text.startswith("#"):
        digits = text[1:]
        if len(digits) in (3, 4):
            # Expand shorthand: "#0f0" -> "00ff00".
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) not in (6, 8):
            raise ValueError(f"Unsupported color value: {picker_color!r}")
        channels = [int(digits[i:i + 2], 16) for i in range(0, len(digits), 2)]
        if len(channels) == 3:
            channels.append(1)
        else:
            channels[3] = int(channels[3] / 255)
        return channels

    # str.strip with a character set removes the "rgba(" / "rgb(" prefix and
    # the closing ")" in one go; float() tolerates spaces around numbers.
    parts = text.strip("rgba()").split(",")
    if len(parts) == 3:
        parts.append("1")
    if len(parts) != 4:
        raise ValueError(f"Unsupported color value: {picker_color!r}")
    return [int(float(part)) for part in parts]
#@spaces.GPU(duration=120)
def process_images(image,draw_number,font_scale,text_color_text,dot_size,dot_color_text,line_size,line_color_text,box_size,box_color_text,json_format,draw_mesh=False,progress=gr.Progress(track_tqdm=True)):
    """Detect face landmarks on ``image`` and render the 68-point overlay.

    Steps: make sure the output directory exists and purge stale files,
    run ``mp_box.mediapipe_to_box``, optionally draw the MediaPipe mesh,
    draw the landmarks via ``draw_landmarks68.draw_landmarks_on_image``,
    and write the landmark JSON to a file for download.

    Args:
        image: Input image; ``None`` means nothing was uploaded.
        draw_number, font_scale, dot_size, line_size, box_size: Drawing
            options forwarded unchanged to ``draw_landmarks68``.
        text_color_text, dot_color_text, line_color_text, box_color_text:
            Color strings converted with ``picker_color_to_rgba``.
        json_format: ``"raw"`` returns the landmark points as produced by
            ``draw_landmarks68``; any other value converts them with
            ``draw_landmarks68.convert_to_landmark_group_json``.
        draw_mesh: When true, the mesh from ``mp_box.draw_landmarks_on_image``
            is drawn onto the image before the landmark overlay.
        progress: Gradio progress reporter.

    Returns:
        ``(annotated_image, jsons, json_path)`` - the annotated image, the
        landmark data object, and the path of the written JSON file.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # exist_ok avoids the exists()/mkdir() race between concurrent requests.
    os.makedirs(dir_name, exist_ok=True)
    clear_old_files(dir_name,passed_time)

    if image is None:
        raise gr.Error("Need Image")

    progress(0, desc="Start Mediapipe")

    # Only the landmarker result is needed; the other two values are unused.
    _, _, face_landmarker_result = mp_box.mediapipe_to_box(image)

    text_color = picker_color_to_rgba(text_color_text)
    line_color = picker_color_to_rgba(line_color_text)
    dot_color = picker_color_to_rgba(dot_color_text)
    box_color = picker_color_to_rgba(box_color_text)

    if draw_mesh:
        image = Image.fromarray(mp_box.draw_landmarks_on_image(face_landmarker_result,image))

    # The bounding box returned here is not used by this function.
    annotated_image, _, landmark_points = draw_landmarks68.draw_landmarks_on_image(
        image,face_landmarker_result,draw_number,font_scale,text_color,
        dot_size,dot_color,line_size,line_color,
        box_size,box_color)

    if json_format=="raw":
        jsons = landmark_points
    else:
        jsons = draw_landmarks68.convert_to_landmark_group_json(landmark_points)

    formatted_json = json.dumps(jsons)
    json_path = create_json_download(formatted_json)

    # NOTE(review): the second output is the Python object, not
    # ``formatted_json``; confirm that is what the JSON text area should show.
    return annotated_image,jsons,json_path
def write_file(file_path,text):
    """Write ``text`` to ``file_path`` as UTF-8, replacing any existing content."""
    with open(file_path, mode='w', encoding='utf-8') as handle:
        handle.write(text)
def read_file(file_path):
    """Return the complete text of ``file_path``, decoded as UTF-8."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
# Stylesheet passed to gr.Blocks(css=...) when the UI is built below.
css="""
#col-left {
margin: 0 auto;
max-width: 640px;
}
#col-right {
margin: 0 auto;
max-width: 640px;
}
.grid-container {
display: flex;
align-items: center;
justify-content: center;
gap:10px
}
.image {
width: 128px;
height: 128px;
object-fit: cover;
}
.text {
font-size: 16px;
}
"""
#css=css,
import hashlib
def text_to_sha256(text):
    """Return the hexadecimal SHA-256 digest of ``text`` encoded as UTF-8."""
    return hashlib.sha256(text.encode('utf-8')).hexdigest()
def create_json_download(text):
    """Store ``text`` in a content-addressed JSON file and return its path.

    The file name embeds the first 32 hex characters of the SHA-256 digest
    of ``text``, so identical content always maps to the same file under
    ``dir_name``.
    """
    digest_prefix = text_to_sha256(text)[:32]
    json_path = f"{dir_name}/landmark_{digest_prefix}.json"
    write_file(json_path, text)
    return json_path
# Build the Gradio UI: header/tools HTML, an input column (upload, run button,
# advanced drawing settings), an output column (annotated image, JSON text,
# download button), example images and a footer.
# NOTE(review): confirm the layout nesting below against the rendered UI.
with gr.Blocks(css=css, elem_id="demo-container") as demo:
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))
    with gr.Row():
        # Input side: image upload, run button and advanced settings.
        with gr.Column():
            image = gr.Image(height=800,sources=['upload','clipboard'],image_mode='RGB',elem_id="image_upload", type="pil", label="Upload")
            with gr.Row(elem_id="prompt-container", equal_height=False):
                with gr.Row():
                    btn = gr.Button("Extract Landmark 68", elem_id="run_button",variant="primary")
            with gr.Accordion(label="Advanced Settings", open=False):
                # Landmark index labels.
                with gr.Row( equal_height=True):
                    draw_number = gr.Checkbox(label="draw Number",value=True)
                    font_scale = gr.Slider(
                        label="Font Scale",
                        minimum=0.1,
                        maximum=2,
                        step=0.1,
                        value=0.5)
                    text_color = gr.ColorPicker(value="rgba(200,200,200,1)",label="text color")
                    #square_shape = gr.Checkbox(label="Square shape")
                # Connecting lines.
                with gr.Row( equal_height=True):
                    line_color = gr.ColorPicker(value="rgba(0,0,255,1)",label="line color")
                    line_size = gr.Slider(
                        label="Line Size",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=1)
                # Landmark dots.
                with gr.Row( equal_height=True):
                    dot_color = gr.ColorPicker(value="rgba(255,0,0,1)",label="dot color")
                    dot_size = gr.Slider(
                        label="Dot Size",
                        minimum=0,
                        maximum=40,
                        step=1,
                        value=3)
                # Bounding box.
                with gr.Row( equal_height=True):
                    box_color = gr.ColorPicker(value="rgba(200,200,200,1)",label="box color")
                    box_size = gr.Slider(
                        label="Box Size",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=1)
                # Output format and mesh toggle. The checkbox defaults to True
                # and its value is always passed to process_images, so the
                # function's own draw_mesh=False default is not used from the UI.
                with gr.Row( equal_height=True):
                    json_format = gr.Radio(choices=["raw","face-detection"],value="face-detection",label="json-output format")
                    draw_mesh = gr.Checkbox(value=True,label="draw mesh",info="draw mediapipe mesh")
        # Output side: annotated image, JSON text and download button.
        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img")
            text_out = gr.TextArea(label="JSON-Output")
            download_button = gr.DownloadButton(label="Download JSON" )
            #download_button.click(fn=json_download,inputs=text_out,outputs=download_button)
    # The order of `inputs` must match the positional parameters of
    # process_images; the three outputs receive its returned tuple.
    btn.click(fn=process_images, inputs=[image,draw_number,font_scale,text_color,
                                         dot_size,dot_color,line_size,line_color,
                                         box_size,box_color,json_format,draw_mesh], outputs =[image_out,text_out,download_button], api_name='infer')
    # Example images that fill the upload component.
    gr.Examples(
        examples =["examples/00003245_00.jpg","examples/00004200.jpg","examples/00005259.jpg","examples/00018022.jpg","examples/img-above.jpg","examples/img-below.jpg","examples/img-side.jpg"],
        inputs=[image]
    )
    gr.HTML(read_file("demo_footer.html"))
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()