File size: 4,156 Bytes
2fe55e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import io
import math

from PIL import Image, ImageDraw

from lama_cleaner.helper import load_img
from lama_cleaner.plugins.base_plugin import BasePlugin


def keep_ratio_resize(img, size, resample=Image.BILINEAR):
    """Resize ``img`` so that its longer side equals ``size``, keeping aspect ratio.

    Square images (and images taller than wide) are scaled by their height;
    wider images are scaled by their width. The shorter side is truncated
    toward zero, but clamped to at least 1 pixel so that very elongated
    inputs never produce a zero-sized dimension.

    Args:
        img: PIL image to resize.
        size: Target length of the longer side, in pixels.
        resample: Resampling filter forwarded to ``Image.resize``.

    Returns:
        The resized image returned by ``img.resize``.
    """
    if img.width > img.height:
        w = size
        # max(1, ...) guards against e.g. a 10000x1 image collapsing to 0 px
        h = max(1, int(img.height * size / img.width))
    else:
        h = size
        w = max(1, int(img.width * size / img.height))
    return img.resize((w, h), resample)


def cubic_bezier(p1, p2, duration: int, frames: int):
    """Sample a cubic Bezier curve that runs from (0, 0) to (1, 1).

    ``p1`` and ``p2`` are the two inner control points; the outer control
    points are fixed at (0, 0) and (1, 1). The curve parameter ``t`` is
    sampled at ``step / frames`` for ``step`` in ``range(0, frames, duration)``
    (i.e. every value strictly below 1), and the true end point (1, 1) is
    appended afterwards.

    Args:
        p1: First inner control point as an ``(x, y)`` pair.
        p2: Second inner control point as an ``(x, y)`` pair.
        duration: Stride between consecutive samples, in units of
            ``1 / frames``. Despite its name this is a step size, not a
            total length; ``1`` samples every frame.
        frames: Number of steps that ``t`` in ``[0, 1]`` is divided into.

    Returns:
        List of ``(x, y)`` tuples with ``len(range(0, frames, duration)) + 1``
        entries, the last one being ``(1, 1)``.

    Raises:
        ValueError: If ``duration`` is 0 (raised by ``range``).
    """
    x0, y0 = (0, 0)
    x1, y1 = p1
    x2, y2 = p2
    x3, y3 = (1, 1)

    def bezier(t, c0, c1, c2, c3):
        # Bernstein form of a cubic Bezier for a single coordinate.
        return (
            math.pow(1 - t, 3) * c0
            + 3 * math.pow(1 - t, 2) * t * c1
            + 3 * (1 - t) * math.pow(t, 2) * c2
            + math.pow(t, 3) * c3
        )

    res = []
    for step in range(0, frames, duration):
        t = step / frames
        res.append((bezier(t, x0, x1, x2, x3), bezier(t, y0, y1, y2, y3)))

    # The loop stops before t = 1, so close the curve on its real end point
    # (x3, y3) rather than an off-curve (1, 0).
    res.append((x3, y3))
    return res


def _wipe_frames(
    base, overlay, positions, overlay_on_right, splitter_width, splitter_color
):
    """Render one wipe transition of ``overlay`` across ``base``.

    For every relative edge position in ``positions`` (0..1 of the frame
    width), a copy of ``base`` is made and the matching slice of ``overlay``
    is pasted on top: the part right of the edge when ``overlay_on_right`` is
    true, otherwise the part left of it. A splitter line is drawn on the edge
    of every frame except the last, so the transition ends on a clean frame.

    ``base`` and ``overlay`` are expected to be RGB images of the same size.

    Returns:
        List of rendered frames, one per entry of ``positions``.
    """
    width, height = base.size
    last = len(positions) - 1
    frames = []
    for idx, pos in enumerate(positions):
        edge = int(pos * width)
        if overlay_on_right:
            box = (edge, 0, width, height)
        else:
            box = (0, 0, edge, height)
        frame = base.copy()
        if box[2] > box[0]:  # nothing to paste for a zero-width slice
            frame.paste(overlay.crop(box), box[:2])
        if idx != last:
            # "RGBA" draw mode blends the splitter's alpha onto the RGB frame
            # instead of silently discarding it.
            ImageDraw.Draw(frame, "RGBA").line(
                [(edge, 0), (edge, height)], width=splitter_width, fill=splitter_color
            )
        frames.append(frame)
    return frames


def make_compare_gif(
    clean_img: Image.Image,
    src_img: Image.Image,
    max_side_length: int = 600,
    splitter_width: int = 5,
    splitter_color=(255, 203, 0, int(255 * 0.73)),
):
    """Build an animated GIF that wipes back and forth between two images.

    The animation opens on ``clean_img`` and wipes ``src_img`` in from the
    right. It then holds on ``src_img``, wipes ``clean_img`` back in from the
    left, and finally holds on ``clean_img``. The GIF loops forever.

    Args:
        clean_img: First image shown; resized to ``src_img``'s size if the
            two differ.
        src_img: Image revealed by the first wipe.
        max_side_length: Upper bound for the longer side of the output. The
            images are never scaled above their own longer side.
        splitter_width: Width in pixels of the line drawn on the wipe edge.
        splitter_color: RGB or RGBA fill of the splitter line; an alpha
            component is blended onto the frame.

    Returns:
        The encoded GIF as ``bytes``.
    """
    if clean_img.size != src_img.size:
        clean_img = clean_img.resize(src_img.size, Image.BILINEAR)
    # Wipe frames are RGB; normalise the inputs so the first frame and the
    # held frames share that mode.
    clean_img = clean_img.convert("RGB")
    src_img = src_img.convert("RGB")

    duration_per_frame = 20  # passed to PIL's GIF writer as `duration`
    num_frames = 50
    hold_frames = 30  # frames spent resting on each fully revealed image

    # Edge positions rising from 0 to 1. cubic_bezier appends the t = 1 end
    # point, so there are num_frames + 1 entries covering both extremes.
    # NOTE(review): only the x coordinate drives the wipe. With control points
    # (0.33, 0) / (0.66, 1), x(t) is very close to t, so the edge moves at a
    # near-constant speed; the eased y coordinate may have been intended.
    rising = [
        point[0] for point in cubic_bezier((0.33, 0), (0.66, 1), 1, num_frames)
    ]
    falling = rising[::-1]

    max_side_length = min(max_side_length, max(clean_img.size))
    src_img = keep_ratio_resize(src_img, max_side_length)
    clean_img = keep_ratio_resize(clean_img, max_side_length)

    # Right to left: src_img slides in over clean_img. Every position is used
    # (down to 0), so the last wipe frame is exactly src_img and there is no
    # jump into the hold frames.
    images = _wipe_frames(
        clean_img, src_img, falling, True, splitter_width, splitter_color
    )
    images.extend([src_img] * hold_frames)

    # Left to right: clean_img slides back in over src_img, ending fully shown.
    images.extend(
        _wipe_frames(
            src_img, clean_img, rising, False, splitter_width, splitter_color
        )
    )
    images.extend([clean_img] * hold_frames)

    img_byte_arr = io.BytesIO()
    clean_img.save(
        img_byte_arr,
        format="GIF",
        save_all=True,
        include_color_table=True,
        append_images=images,
        optimize=False,
        duration=duration_per_frame,
        loop=0,
    )
    return img_byte_arr.getvalue()


class MakeGIF(BasePlugin):
    """Plugin that turns an image and its cleaned counterpart into a GIF.

    The cleaned image is read from the uploaded files; the result is the GIF
    byte string produced by ``make_compare_gif``.
    """

    name = "MakeGIF"

    def __call__(self, rgb_np_img, files, form):
        """Return comparison-GIF bytes for ``rgb_np_img`` and ``files["clean_img"]``.

        Args:
            rgb_np_img: The original image as an array accepted by
                ``Image.fromarray``.
            files: Mapping whose ``"clean_img"`` entry exposes ``.read()``
                returning the cleaned image's encoded bytes.
            form: Unused by this plugin.
        """
        # load_img returns a pair; only the decoded image is needed here.
        cleaned_np, _ = load_img(files["clean_img"].read())
        # The original occupies make_compare_gif's first (`clean_img`) slot,
        # so the animation opens on the original image.
        return make_compare_gif(
            Image.fromarray(rgb_np_img),
            Image.fromarray(cleaned_np),
        )