File size: 3,679 Bytes
3f85c56
 
 
 
 
ab854ef
d538d26
286713d
3f85c56
52c565a
 
 
 
 
f7998c5
 
ab854ef
f7998c5
3f85c56
 
f7998c5
3f85c56
 
45748ff
3f85c56
 
 
 
b8883eb
3f85c56
 
 
b8883eb
3f85c56
52c565a
 
 
10cf594
 
 
 
 
52c565a
3f85c56
52c565a
3f85c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8883eb
4279b74
3f85c56
 
 
 
 
 
4eb9bf3
52c565a
3f85c56
 
 
52c565a
 
 
 
3f85c56
52c565a
 
 
 
3f85c56
b8883eb
52c565a
 
 
4eb9bf3
 
3f85c56
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch
import gradio as gr
import time
import spaces

from segment_utils import(
    segment_image,
    restore_result,
)

# Select the compute device; the upscaler runs on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f'{device} is available')

# Stable Diffusion x4 upscaler, loaded once at import time with fp16 weights
# to halve GPU memory use.
# NOTE(review): fp16 may be unsupported/slow on CPU fallback — confirm intended.
model_id = "stabilityai/stable-diffusion-x4-upscaler"
upscale_pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
upscale_pipe = upscale_pipe.to(device)

# Default prompt shown in the UI textbox (typo fixed: "pefect" -> "perfect").
DEFAULT_SRC_PROMPT = "a person with perfect face"
DEFAULT_CATEGORY = "face"


def create_demo() -> gr.Blocks:
    """Build the Gradio UI for the segment -> upscale -> restore pipeline.

    The flow on button click: segment the chosen category region out of the
    input image, upscale that region with the x4 upscaler, then paste the
    result back into the original image.

    Returns:
        gr.Blocks: the assembled demo, ready for ``.launch()``.
    """

    @spaces.GPU(duration=30)
    def upscale_image(
        input_image: Image.Image,  # fixed annotation: `Image` is the PIL module, not the class
        prompt: str,
        num_inference_steps: int = 10,
    ):
        """Run the x4 upscaler on ``input_image`` guided by ``prompt``.

        Args:
            input_image: region image produced by the segmentation step.
            prompt: text prompt guiding the upscaler.
            num_inference_steps: diffusion steps; UI default is 5.

        Returns:
            tuple: (upscaled PIL image, per-step time-cost string in ms).
        """
        time_cost_str = ''
        run_task_time = 0
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
        upscaled_image = upscale_pipe(
            prompt=prompt,
            image=input_image,
            num_inference_steps=num_inference_steps,
        ).images[0]
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

        return upscaled_image, time_cost_str

    def get_time_cost(run_task_time, time_cost_str):
        """Append the milliseconds elapsed since the previous checkpoint.

        On the first call (``run_task_time == 0``) the string is reset to
        ``'start'``; later calls append ``-->`` plus the elapsed ms.

        Returns:
            tuple: (current time in ms, updated time-cost string).
        """
        now_time = int(time.time() * 1000)
        if run_task_time == 0:
            time_cost_str = 'start'
        else:
            if time_cost_str != '':
                # plain string literal — was an f-string with no placeholders
                time_cost_str += '-->'
            time_cost_str += f'{now_time - run_task_time}'
        run_task_time = now_time
        return run_task_time, time_cost_str

    with gr.Blocks() as demo:
        # Holds the cropper state returned by segment_image so restore_result
        # can paste the upscaled region back — presumably a crop-box object;
        # verify against segment_utils.
        croper = gr.State()
        with gr.Row():
            with gr.Column():
                input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
            with gr.Column():
                num_inference_steps = gr.Number(label="Num Inference Steps", value=5)
                generate_size = gr.Number(label="Generate Size", value=512)
                g_btn = gr.Button("Upscale Image")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                restored_image = gr.Image(label="Restored Image", format="png", type="pil", interactive=False)
                # Hidden intermediate: the segmented region before upscaling.
                origin_area_image = gr.Image(label="Origin Area Image", format="png", type="pil", interactive=False, visible=False)
                upscaled_image = gr.Image(label="Upscaled Image", format="png", type="pil", interactive=False)
                download_path = gr.File(label="Download the output image", interactive=False)
                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
                # Hidden segmentation knobs kept fixed for this demo.
                category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
                mask_expansion = gr.Number(label="Mask Expansion", value=20, visible=False)
                mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation", visible=False)

        # Chain: segment the region, then upscale it, then restore it into
        # the original image. Each stage runs only if the previous succeeded.
        g_btn.click(
            fn=segment_image,
            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
            outputs=[origin_area_image, croper],
        ).success(
            fn=upscale_image,
            inputs=[origin_area_image, input_image_prompt, num_inference_steps],
            outputs=[upscaled_image, generated_cost],
        ).success(
            fn=restore_result,
            inputs=[croper, category, upscaled_image],
            outputs=[restored_image, download_path],
        )

    return demo