# Uploaded by: hirol
# Commit message: "Upload 27 files"
# Commit: a637d5e
# File view links: raw | history | blame
# File size: 5.66 kB
import os, glob, json, base64, re
from io import BytesIO
from PIL import Image, PngImagePlugin
from image_process import image_canny,image_pose_mask,image_pose_mask_numpy
from generate_img import generate_image, generate_image_sketch
# Absolute directory holding the saved pose PNGs; empty until set_save_dir() runs.
_SAVED_POSES_DIR = ''
# Shared image cache (in this file only referenced by commented-out code).
image_cache = {}
def set_save_dir(dir: str):
    """Configure the directory that pose PNGs are read from and written to.

    The argument is converted with ``str`` and canonicalised with
    ``os.path.realpath`` before being stored in the module-level setting.
    """
    global _SAVED_POSES_DIR
    resolved = os.path.realpath(str(dir))
    _SAVED_POSES_DIR = resolved
def get_save_dir():
    """Return the pose save directory configured by ``set_save_dir``.

    Raises:
        RuntimeError: if no directory has been configured yet.

    An explicit ``raise`` is used instead of ``assert`` because assertions
    are stripped under ``python -O``. Without the check this would return
    ``''``, making every saved-pose path relative to the current working
    directory and turning the ``startswith(get_save_dir())`` containment
    check in ``name2path`` into a no-op.
    """
    if not _SAVED_POSES_DIR:
        raise RuntimeError('pose save directory is not set; call set_save_dir() first')
    return _SAVED_POSES_DIR
def get_saved_path(name: str):
    """Join *name* onto the configured save directory.

    The joined path is returned as-is (no ``realpath``); validating *name*
    is left to callers such as ``name2path``.
    """
    base = get_save_dir()
    return os.path.join(base, name)
def atoi(text):
    """Return *text* as an ``int`` if it is a run of decimal digits, else unchanged.

    Used as the per-chunk key for natural sorting. ``str.isdecimal`` is used
    rather than ``str.isdigit`` because ``isdigit`` also accepts characters
    such as superscript two ('²') that ``int()`` rejects with ``ValueError``.
    """
    return int(text) if text.isdecimal() else text
def natural_keys(text):
    """Build a natural-sort key for *text*.

    The string is split into alternating non-digit and digit chunks, and
    each chunk is passed through ``atoi`` so numeric runs compare as numbers
    (e.g. 'pose2' before 'pose10').
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))
def sorted_glob(path):
    """Return the paths matching glob pattern *path*, in natural sort order."""
    matches = glob.glob(path)
    matches.sort(key=natural_keys)
    return matches
def name2path(name: str):
    """Map a pose *name* to its PNG path inside the save directory.

    Raises:
        ValueError: if *name* is not a ``str``, is empty, contains a dot,
            forward slash or backslash, or the joined path does not start
            with the save directory.
    """
    if not isinstance(name, str):
        raise ValueError(f'str object expected, but {type(name)}')
    if not name:
        raise ValueError('empty name')
    # Forbidding '.', '/' and '\' keeps the name a single path component,
    # so it cannot climb out of the save directory (e.g. '../x').
    if any(ch in name for ch in ('.', '/', '\\')):
        raise ValueError(f'invalid name: {name}')
    candidate = get_saved_path(f'{name}.png')
    if candidate.startswith(get_save_dir()):
        return candidate
    raise ValueError(f'invalid name: {name}')
def saved_poses():
    """Lazily open every ``*.png`` in the save directory, in natural order.

    Yields the objects returned by ``Image.open``; nothing is closed here,
    so each yielded image is owned by the consumer.
    """
    pattern = os.path.join(get_save_dir(), '*.png')
    yield from map(Image.open, sorted_glob(pattern))
def all_poses():
    """Yield each saved pose as a JSON-serialisable dict.

    Each dict has the keys:
        name:   the pose name text chunk,
        image:  the PNG re-encoded and base64-encoded as ASCII,
        screen / camera / joints: JSON-decoded from the PNG text chunks
                that ``save_pose`` writes.

    PNG files in the save directory that lack any of those text chunks are
    skipped, so a stray non-pose PNG does not abort the whole listing with
    a ``KeyError``. Each image is closed once its data has been captured.
    """
    required = ('name', 'screen', 'camera', 'joints')
    for img in saved_poses():
        with img:
            text = getattr(img, 'text', None) or {}
            # Check before re-encoding so non-pose files cost nothing extra.
            if not all(key in text for key in required):
                continue
            buffer = BytesIO()
            img.save(buffer, format='png')
            pose_dict = {
                'name': text['name'],
                'image': base64.b64encode(buffer.getvalue()).decode('ascii'),
                'screen': json.loads(text['screen']),
                'camera': json.loads(text['camera']),
                'joints': json.loads(text['joints']),
            }
        yield pose_dict
def save_pose(data: dict):
    """Persist a client-supplied pose as a PNG thumbnail with metadata.

    *data* must contain ``name``, ``screen``, ``camera``, ``joints`` and
    ``image``. ``image`` is base64 PNG data, either bare or as a
    ``data:<mime>;base64,<payload>`` URL. ``name`` is stored verbatim as a
    PNG text chunk; ``screen``/``camera``/``joints`` are stored JSON-encoded.
    The image is centred on a square (68, 68, 68) canvas, downscaled to a
    quarter of its side, and written to ``name2path(name)``.

    Raises:
        KeyError: if a required key is missing from *data*.
        ValueError: from ``name2path`` if *name* is invalid.
    """
    print(data)
    name = data['name']
    screen = data['screen']
    camera = data['camera']
    joints = data['joints']
    info = PngImagePlugin.PngInfo()
    info.add_text('name', name)
    info.add_text('screen', json.dumps(screen))
    info.add_text('camera', json.dumps(camera))
    info.add_text('joints', json.dumps(joints))
    # Validate the name before doing any image work.
    filepath = name2path(name)
    # Split on the first comma (never part of the base64 alphabet) instead of
    # slicing a fixed-length "data:image/png;base64," prefix. The fixed slice
    # corrupts any other data-URL header and chops the start off bare base64.
    encoded = data['image'].split(',', 1)[-1]
    image = Image.open(BytesIO(base64.b64decode(encoded)))
    # Centre the image on a square canvas.
    unit = max(image.width, image.height)
    offset = ((unit - image.width) // 2, (unit - image.height) // 2)
    canvas = Image.new('RGB', (unit, unit), color=(68, 68, 68))
    canvas.paste(image, offset)
    # Quarter-size thumbnail; clamp to 1 px so inputs smaller than 4 px
    # do not request a 0x0 resize.
    side = max(1, unit // 4)
    thumbnail = canvas.resize((side, side))
    thumbnail.save(filepath, pnginfo=info)
def delete_pose(name: str):
    """Remove the saved PNG for pose *name*.

    Raises:
        ValueError: if *name* is rejected by ``name2path``.
        OSError: if ``os.remove`` fails (for example, the file is missing).
    """
    os.remove(name2path(name))
def load_pose(name: str):
    """Load one saved pose by *name* as a JSON-serialisable dict.

    Returns a dict with the keys:
        name:   the pose name text chunk,
        image:  the PNG re-encoded and base64-encoded as ASCII,
        screen / camera / joints: JSON-decoded from the PNG text chunks.

    Raises:
        ValueError: if *name* is rejected by ``name2path``, or if the file
            lacks any of the pose text chunks ("not pose data").
        OSError: if the file cannot be opened.

    The image file is closed before returning.
    """
    filepath = name2path(name)
    required = ('name', 'screen', 'camera', 'joints')
    with Image.open(filepath) as img:
        text = getattr(img, 'text', None) or {}
        # A PNG without the pose chunks is reported as ValueError,
        # not as a bare KeyError from the dict lookups below.
        if not all(key in text for key in required):
            raise ValueError(f'not pose data: {filepath}')
        buffer = BytesIO()
        img.save(buffer, format='png')
    return {
        'name': text['name'],
        'image': base64.b64encode(buffer.getvalue()).decode('ascii'),
        'screen': json.loads(text['screen']),
        'camera': json.loads(text['camera']),
        'joints': json.loads(text['joints']),
    }
def base64_PIL(data: str):
    """Decode base64 text holding an encoded image file and open it with PIL.

    *data* is expected to be bare base64 without a ``data:`` URL header;
    the callers in this module strip that header before calling.
    """
    raw = base64.b64decode(data)
    stream = BytesIO(raw)
    return Image.open(stream)
def PIL_base64(data):
    """Return ``data.tobytes()`` base64-encoded as a UTF-8 string.

    NOTE(review): for a PIL image, ``tobytes()`` yields the raw pixel data,
    not an encoded PNG/JPEG file. The result carries no size or mode, and
    ``base64_PIL`` (which goes through ``Image.open``) cannot decode it back.
    """
    raw = data.tobytes()
    encoded = base64.b64encode(raw)
    return encoded.decode('utf-8')
def resizeImg(image1,image2):
    """Return *image2* resized to the width and height of *image1*.

    Only ``image1.size`` is read from *image1*.
    """
    # Use image1's dimensions as the resize target for image2.
    target_w, target_h = image1.size
    return image2.resize((target_w, target_h))
# def get_img(data):
# # execution logic
# if (data[0]):
# bgImgBase64 = data[0]['bgImg'][len('data:image/png;base64,'):]
# maskImgBase64 = data[0]['maskImg'][len('data:image/png;base64,'):]
# image_cache['bgImgBase64'] = bgImgBase64
# image_cache['maskImgBase64'] = maskImgBase64
# return 'success'
def generate_img(data, image_prompt, image_n_prompt):
    """Generate an image from a background picture and an OpenPose render.

    Args:
        data: sequence whose first item is a mapping with ``'bgImg'`` and
            ``'maskImg'`` entries, each base64 PNG data (bare or as a
            ``data:<mime>;base64,`` URL). ``maskImg`` holds the pose render.
        image_prompt: prompt forwarded to ``generate_image``.
        image_n_prompt: second prompt forwarded to ``generate_image``
            (presumably the negative prompt — confirm).

    Returns:
        ``[result]`` holding the output of ``generate_image``, or ``None``
        when *data* is empty/``None`` or its first item is falsy.
    """
    # Treat missing input like a falsy first entry instead of raising IndexError.
    if not data or not data[0]:
        return None
    entry = data[0]
    # Drop any "data:<mime>;base64," header; ',' never occurs in base64 text.
    bg_img = entry['bgImg'].split(',', 1)[-1]
    mask_img_openpose = entry['maskImg'].split(',', 1)[-1]
    print((len(bg_img), len(mask_img_openpose)))
    print((image_prompt, image_n_prompt))
    # image_pose_mask's result is decoded with base64_PIL below, so it is
    # treated as base64 image data.
    mask_b64 = image_pose_mask(mask_img_openpose)
    controlnet_img_pil = base64_PIL(mask_img_openpose)
    bg_img_pil = base64_PIL(bg_img)
    mask_img_pil = base64_PIL(mask_b64)
    # Match the background to the mask's dimensions.
    bg_img_pil = resizeImg(mask_img_pil, bg_img_pil)
    img = generate_image(image_prompt, image_n_prompt, controlnet_img_pil, bg_img_pil, mask_img_pil)
    return [img]
def get_image_sketch(image, image_prompt, image_n_prompt):
    """Run ``generate_image_sketch`` on an image the user has drawn over.

    *image* is a mapping with ``'image'`` (the original picture) and
    ``'mask'`` (the sketch). Both values are passed to ``Image.fromarray``
    (presumably numpy arrays — confirm against the caller). A mask derived
    from the sketch by ``image_pose_mask_numpy`` is decoded with
    ``base64_PIL``.

    Returns whatever ``generate_image_sketch`` returns.
    """
    origin_array = image['image']
    sketch_array = image['mask']
    mask_image = base64_PIL(image_pose_mask_numpy(sketch_array))
    origin_image = Image.fromarray(origin_array)
    sketch_image = Image.fromarray(sketch_array)
    return generate_image_sketch(
        image_prompt,
        image_n_prompt,
        sketch_image,
        origin_image,
        mask_image,
    )