Spaces:
Runtime error
Runtime error
File size: 5,664 Bytes
a637d5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import os, glob, json, base64, re
from io import BytesIO
from PIL import Image, PngImagePlugin
from image_process import image_canny,image_pose_mask,image_pose_mask_numpy
from generate_img import generate_image, generate_image_sketch
# Absolute directory holding saved pose PNGs; stays empty until set_save_dir() runs.
_SAVED_POSES_DIR = ""
# Module-level image cache (only referenced by commented-out code in this section).
image_cache = {}
def set_save_dir(dir: str):
    """Set the directory in which poses are saved and looked up.

    The argument is converted with ``str()`` and canonicalised with
    ``os.path.realpath`` before being stored in ``_SAVED_POSES_DIR``.
    """
    global _SAVED_POSES_DIR
    resolved_dir = os.path.realpath(str(dir))
    _SAVED_POSES_DIR = resolved_dir
def get_save_dir():
    """Return the configured save directory.

    Raises:
        RuntimeError: if ``set_save_dir`` has not been called yet. An explicit
            raise is used instead of ``assert`` because asserts are stripped
            under ``python -O``. Without this check an empty directory would
            make every pose path resolve relative to the current working
            directory.
    """
    if not _SAVED_POSES_DIR:
        raise RuntimeError('save directory is not set; call set_save_dir() first')
    return _SAVED_POSES_DIR
def get_saved_path(name: str):
    """Join *name* onto the save directory.

    The result is not normalised or resolved. Callers such as ``name2path``
    are responsible for validating *name* first.
    """
    base_dir = get_save_dir()
    return os.path.join(base_dir, name)
def atoi(text):
    """Return ``int(text)`` when *text* consists only of decimal digits, else *text*.

    Used as a natural-sort key component. ``str.isdecimal`` is used instead of
    ``isdigit``. ``isdigit`` also accepts characters such as superscripts
    (``'²'``) that ``int()`` cannot parse, which made this raise ``ValueError``.
    ``isdecimal`` matches exactly the characters ``int()`` and regex ``\\d``
    accept.
    """
    return int(text) if text.isdecimal() else text
def natural_keys(text):
    """Build a natural-sort key for *text*.

    The text is split around runs of digits, and each piece goes through
    ``atoi``. Numeric runs therefore compare as integers, so 'pose9' sorts
    before 'pose10'.
    """
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))
def sorted_glob(path):
    """Return the paths matching glob pattern *path* in natural sort order."""
    matches = glob.glob(path)
    matches.sort(key=natural_keys)
    return matches
def name2path(name: str):
    """Map a pose *name* to ``<save_dir>/<name>.png``.

    Raises:
        ValueError: if *name* is not a ``str``, is empty, contains ``'.'``,
            ``'/'`` or ``'\\'``, or the resulting path does not start with
            the save directory.
    """
    if not isinstance(name, str):
        raise ValueError(f'str object expected, but {type(name)}')
    if not name:
        raise ValueError('empty name')
    # Forbidding dots and separators rules out '..' traversal and sub-paths.
    if any(forbidden in name for forbidden in ('.', '/', '\\')):
        raise ValueError(f'invalid name: {name}')
    candidate = get_saved_path(f'{name}.png')
    # Defence in depth: the joined path must still sit under the save dir.
    if candidate.startswith(get_save_dir()):
        return candidate
    raise ValueError(f'invalid name: {name}')
def saved_poses():
    """Yield an image for every ``*.png`` in the save directory, in natural order.

    Each item comes straight from ``Image.open``. This generator does not
    close the images; that is left to the consumer.
    """
    pattern = os.path.join(get_save_dir(), '*.png')
    yield from (Image.open(png_path) for png_path in sorted_glob(pattern))
def all_poses():
    """Yield every saved pose as a JSON-serialisable dict.

    Each dict has the keys:

    - ``'name'``: the stored name.
    - ``'image'``: base64 (ASCII) of the image re-encoded as PNG.
    - ``'screen'``, ``'camera'``, ``'joints'``: parsed from the JSON text
      entries written by ``save_pose``.

    Files are skipped when they have no ``text`` attribute (presumably
    non-PNG files) or lack any of the pose text entries. Such a PNG
    previously raised ``KeyError`` and aborted the whole iteration.
    """
    required_keys = ('name', 'screen', 'camera', 'joints')
    for img in saved_poses():
        # Close each file once its dict has been produced.
        with img:
            text = getattr(img, 'text', None)
            if not text or any(key not in text for key in required_keys):
                continue
            buffer = BytesIO()
            img.save(buffer, format='png')
            yield {
                'name': text['name'],
                'image': base64.b64encode(buffer.getvalue()).decode('ascii'),
                'screen': json.loads(text['screen']),
                'camera': json.loads(text['camera']),
                'joints': json.loads(text['joints']),
            }
def save_pose(data: dict):
    """Save a pose as a PNG thumbnail with its metadata in PNG text entries.

    *data* must provide ``'name'``, ``'screen'``, ``'camera'``, ``'joints'``
    and ``'image'``:

    - ``'screen'``, ``'camera'`` and ``'joints'`` are stored JSON-encoded.
    - ``'image'`` is base64 image data, normally a
      ``data:image/png;base64,`` URL.

    The image is centred on a square grey (68, 68, 68) canvas, downscaled by
    4, and written to ``name2path(data['name'])``.

    Raises:
        ValueError: if the name is rejected by ``name2path``.
    """
    # Debug log without dumping the (potentially huge) base64 image payload.
    print({key: value for key, value in data.items() if key != 'image'})
    name = data['name']
    filepath = name2path(name)  # validate the name before any image work

    info = PngImagePlugin.PngInfo()
    info.add_text('name', name)
    info.add_text('screen', json.dumps(data['screen']))
    info.add_text('camera', json.dumps(data['camera']))
    info.add_text('joints', json.dumps(data['joints']))

    encoded = data['image']
    # Strip a data-URL header only when one is present. Slicing off a fixed
    # length silently corrupted raw base64 and non-PNG data URLs.
    if encoded.startswith('data:'):
        encoded = encoded.partition(',')[2]
    image = Image.open(BytesIO(base64.b64decode(encoded)))

    unit = max(image.width, image.height)
    offset = ((unit - image.width) // 2, (unit - image.height) // 2)
    canvas = Image.new('RGB', (unit, unit), color=(68, 68, 68))
    canvas.paste(image, offset)
    # Keep at least 1 px so images smaller than 4 px do not yield a 0-size thumbnail.
    thumb_size = (max(1, canvas.width // 4), max(1, canvas.height // 4))
    canvas.resize(thumb_size).save(filepath, pnginfo=info)
def delete_pose(name: str):
    """Remove the saved pose file for *name*.

    Names rejected by ``name2path`` raise ``ValueError``. Errors from
    ``os.remove`` (e.g. a missing file) propagate unchanged.
    """
    os.remove(name2path(name))
def load_pose(name: str):
    """Load one saved pose by *name*.

    Returns a dict with the keys:

    - ``'name'``: the stored name.
    - ``'image'``: base64 (ASCII) of the stored image re-encoded as PNG.
    - ``'screen'``, ``'camera'``, ``'joints'``: parsed from the JSON text
      entries written by ``save_pose``.

    Raises:
        ValueError: if *name* is rejected by ``name2path``, or the file has
            no ``text`` attribute or lacks any of the pose text entries.
            Previously a PNG missing those entries raised a bare ``KeyError``.
    """
    filepath = name2path(name)
    required_keys = ('name', 'screen', 'camera', 'joints')
    with Image.open(filepath) as img:
        text = getattr(img, 'text', None)
        if not text or any(key not in text for key in required_keys):
            raise ValueError(f'not pose data: {filepath}')
        buffer = BytesIO()
        img.save(buffer, format='png')
    return {
        'name': text['name'],
        'image': base64.b64encode(buffer.getvalue()).decode('ascii'),
        'screen': json.loads(text['screen']),
        'camera': json.loads(text['camera']),
        'joints': json.loads(text['joints']),
    }
def base64_PIL(data: str):
    """Decode a base64 string into an image via ``Image.open``.

    Expects bare base64. Callers in this file strip the data-URL header first.
    """
    raw_bytes = base64.b64decode(data)
    return Image.open(BytesIO(raw_bytes))
def PIL_base64(data, image_format=None):
    """Base64-encode an image and return the result as a ``str``.

    By default this encodes ``data.tobytes()``: the raw pixel buffer, not an
    encoded image file. That output is therefore NOT decodable by
    ``base64_PIL``. The default is kept for backward compatibility.

    Pass *image_format* (e.g. ``'PNG'``) to encode a complete image file via
    ``data.save``. The result then round-trips through ``base64_PIL``.
    """
    if image_format is None:
        payload = data.tobytes()
    else:
        buffer = BytesIO()
        data.save(buffer, format=image_format)
        payload = buffer.getvalue()
    return base64.b64encode(payload).decode('utf-8')
def resizeImg(image1, image2):
    """Return *image2* resized to the width and height of *image1*.

    The result of ``image2.resize`` is returned. *image1* is only used for
    its ``size``.
    """
    target_width, target_height = image1.size
    return image2.resize((target_width, target_height))
# def get_img(data):
# # execution logic
# if (data[0]):
# bgImgBase64 = data[0]['bgImg'][len('data:image/png;base64,'):]
# maskImgBase64 = data[0]['maskImg'][len('data:image/png;base64,'):]
# image_cache['bgImgBase64'] = bgImgBase64
# image_cache['maskImgBase64'] = maskImgBase64
# return 'success'
def generate_img(data, image_prompt, image_n_prompt):
    """Generate an image from the background and pose images in ``data[0]``.

    ``data[0]`` must be a mapping with ``'bgImg'`` and ``'maskImg'``, each a
    base64 image, normally a ``data:image/png;base64,`` URL.

    The pose image is used in two ways:

    - decoded and passed to ``generate_image`` in the ControlNet-image slot
      (per the local variable name);
    - fed to ``image_pose_mask`` to build the mask.

    The background is resized to the mask's dimensions.

    Returns:
        ``[img]`` with the generated image, or ``None`` when *data* is empty,
        ``None`` or its first entry is falsy. Previously an empty or ``None``
        *data* raised ``IndexError``/``TypeError``.
    """
    if not data or not data[0]:
        return None
    payload = data[0]

    def _strip_data_url(url):
        # 'data:image/png;base64,XXXX' -> 'XXXX'; bare base64 passes through
        # instead of losing its first 22 characters to a fixed-length slice.
        return url.partition(',')[2] if url.startswith('data:') else url

    bg_img = _strip_data_url(payload['bgImg'])
    mask_img_openpose = _strip_data_url(payload['maskImg'])
    print((len(bg_img), len(mask_img_openpose)))
    print((image_prompt, image_n_prompt))

    mask_img_base64 = image_pose_mask(mask_img_openpose)
    controlnet_img_pil = base64_PIL(mask_img_openpose)
    bg_img_pil = base64_PIL(bg_img)
    mask_img_pil = base64_PIL(mask_img_base64)
    # Resize the background to the mask's dimensions.
    bg_img_pil = resizeImg(mask_img_pil, bg_img_pil)
    img = generate_image(image_prompt, image_n_prompt, controlnet_img_pil, bg_img_pil, mask_img_pil)
    return [img]
def get_image_sketch(image, image_prompt, image_n_prompt):
    """Generate an image from a sketch drawn over an original picture.

    *image* is a mapping with ``'image'`` (the original picture) and
    ``'mask'`` (the sketch layer), both arrays passed to ``Image.fromarray``.
    The generation mask is derived from the sketch via
    ``image_pose_mask_numpy``, whose base64 output is decoded with
    ``base64_PIL``.

    Returns whatever ``generate_image_sketch`` returns.
    """
    origin_array = image['image']
    sketch_array = image['mask']
    mask_image = base64_PIL(image_pose_mask_numpy(sketch_array))
    origin_image = Image.fromarray(origin_array)
    sketch_image = Image.fromarray(sketch_array)
    return generate_image_sketch(
        image_prompt,
        image_n_prompt,
        sketch_image,
        origin_image,
        mask_image,
    )