import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard

import gradio_utils
import utils

def create_key(seed=0):
    return jax.random.PRNGKey(seed)
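
# Sketch of the multi-device setup used throughout this file (explanatory
# note, not original code): the key from create_key is split into one key per
# device, and Flax's `shard`/`replicate` give the inputs and parameters a
# leading device axis so the pipeline's jit=True (pmap) path can consume them:
#
#     rng = jax.random.split(create_key(42), jax.device_count())  # (n_dev, 2)
#     batch = shard(prompt_ids)        # (B, ...) -> (n_dev, B // n_dev, ...)
#     params = replicate(self.params)  # copy the parameter tree to each device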

class Model:
    def __init__(self, **kwargs):
        # Load the OpenPose ControlNet weights (converted from the PyTorch
        # checkpoint) and the Flax Stable Diffusion pipeline that wraps them.
        self.base_controlnet, self.base_controlnet_params = FlaxControlNetModel.from_pretrained(
            # Alternative checkpoint: "JFoz/dog-cat-pose", dtype=jnp.bfloat16
            "lllyasviel/control_v11p_sd15_openpose", dtype=jnp.bfloat16, from_pt=True
        )
        self.pipe, self.params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.base_controlnet,
            revision="flax",
            dtype=jnp.bfloat16,
        )

    def infer_frame(self, frame_id, prompt, negative_prompt, rng, **kwargs):
        print(prompt[frame_id], frame_id)

        # Tokenize the prompts and preprocess the conditioning image for this frame.
        num_samples = 1
        prompt_ids = self.pipe.prepare_text_inputs([prompt[frame_id]] * num_samples)
        negative_prompt_ids = self.pipe.prepare_text_inputs([negative_prompt[frame_id]] * num_samples)
        processed_image = self.pipe.prepare_image_inputs([kwargs['image'][frame_id]] * num_samples)

        # Inject the ControlNet weights into the pipeline's parameter tree.
        self.params["controlnet"] = self.base_controlnet_params

        # Replicate the parameters to every device and shard the inputs so the
        # pipeline can run under pmap (jit=True).
        p_params = replicate(self.params)
        prompt_ids = shard(prompt_ids)
        negative_prompt_ids = shard(negative_prompt_ids)
        processed_image = shard(processed_image)

        output = self.pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=rng,
            # Honor the caller's step count instead of hardcoding 50.
            num_inference_steps=kwargs.get('num_inference_steps', 50),
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images

        # Merge the device/batch axes back into (num_samples, H, W, C).
        output_images = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
        return output_images

    def inference(self, **kwargs):
        seed = kwargs.pop('seed', 0)

        # One PRNG key per device for the pmapped sampling loop.
        # (Previously the popped seed was ignored and 0 was always used.)
        rng = create_key(seed)
        rng = jax.random.split(rng, jax.device_count())

        f = len(kwargs['image'])
        print('frames', f)

        assert 'prompt' in kwargs
        prompt = [kwargs.pop('prompt')] * f
        negative_prompt = [kwargs.pop('negative_prompt', '')] * f

        # Generate each frame independently and stack into (f, num_samples, H, W, C).
        result = []
        for i in range(f):
            print(f'Processing frame {i + 1} / {f}')
            result.append(self.infer_frame(frame_id=i,
                                           prompt=prompt,
                                           negative_prompt=negative_prompt,
                                           rng=rng,
                                           **kwargs))
        result = np.stack(result, axis=0)
        return result

    def process_controlnet_pose(self,
                                video_path,
                                prompt,
                                num_inference_steps=20,
                                controlnet_conditioning_scale=1.0,
                                guidance_scale=9.0,
                                seed=42,
                                eta=0.0,
                                resolution=512,
                                save_path=None):
        print("Module Pose")
        video_path = gradio_utils.motion_to_video_path(video_path)


        added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth'
        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'

        video, fps = utils.prepare_video(
            video_path, resolution, False, output_fps=4)
        control = utils.pre_process_pose(
            video, apply_pose_detect=False)
        
        print('N frames', len(control))
        f, _, h, w = video.shape

        result = self.inference(image=control,
                                prompt=prompt + ', ' + added_prompt,
                                height=h,
                                width=w,
                                negative_prompt=negative_prompts,
                                num_inference_steps=num_inference_steps,
                                guidance_scale=guidance_scale,
                                controlnet_conditioning_scale=controlnet_conditioning_scale,
                                eta=eta,
                                seed=seed,
                                output_type='numpy',
                                )
        return utils.create_gif(result.astype(jnp.float16), fps, path=save_path)
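

# A minimal usage sketch (not from the original file). The video path and
# save path are hypothetical placeholders; process_controlnet_pose resolves
# the given value via gradio_utils.motion_to_video_path before reading frames.
if __name__ == "__main__":
    model = Model()
    gif = model.process_controlnet_pose(
        video_path="path/to/pose_video.mp4",  # hypothetical input video
        prompt="an astronaut dancing on the moon",
        num_inference_steps=20,
        seed=42,
        save_path="output.gif",  # hypothetical output location
    )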