imsuperkong committed • 55ef7fd • 1 parent: 2abfd01 • Upload app.py
app.py
ADDED
@@ -0,0 +1,456 @@
import gradio as gr
import torch
import numpy as np

import sd.gradio_utils as gradio_utils

import os
import cv2
import argparse
import ipdb

from tqdm import tqdm
from diffusers import DDIMScheduler, DDPMScheduler

from sd.core import DDIMBackward, DDPM_forward

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

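# High-level flow of the demo (as wired up in main() below): gen_img() creates the first
# frame, then every click on the image triggers on_click(), which estimates depth with the
# pipeline's MiDaS model, builds a new camera extrinsic pointing towards the clicked pixel,
# warps the DDIM-inverted latent to that view, and re-denoises it with KV injection and
# feature-correspondence guidance to produce the next frame.
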
def slerp(R_target, rotation_speed):
    # Compute the angle of rotation from the rotation matrix
    angle = np.arccos((np.trace(R_target) - 1) / 2)

    # Handle the case where angle is very small (no significant rotation)
    if angle < 1e-6:
        return np.eye(3)

    # Normalize the angle based on rotation_speed
    normalized_angle = angle * rotation_speed

    # Axis of rotation
    axis = np.array([R_target[2, 1] - R_target[1, 2],
                     R_target[0, 2] - R_target[2, 0],
                     R_target[1, 0] - R_target[0, 1]])
    axis = axis / np.linalg.norm(axis)

    # Return the interpolated rotation matrix
    return cv2.Rodrigues(axis * normalized_angle)[0]


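# Build the camera motion for one step: back-project the clicked pixel through the pinhole
# intrinsics using its depth, rotate the camera towards that ray (scaled by rotation_speed),
# and translate it by (step_x, step_y, step_z). The result is a 4x4 extrinsic matrix whose
# top-left 3x3 block is the rotation and whose last column is the translation.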
def compute_extrinsic_parameters(clicked_point, depth, intrinsic_matrix, rotation_speed, step_x=0, step_y=0, step_z=0):
    # Normalize the clicked point
    x, y = clicked_point
    x = int(x)
    y = int(y)
    x_normalized = (x - intrinsic_matrix[0, 2]) / intrinsic_matrix[0, 0]
    y_normalized = (y - intrinsic_matrix[1, 2]) / intrinsic_matrix[1, 1]

    # Depth at the clicked point
    try:
        z = depth[y, x]
    except Exception:
        ipdb.set_trace()

    # Direction vector in camera coordinates
    direction_vector = np.array([x_normalized * z, y_normalized * z, z])

    # Calculate rotation angles to bring the clicked point to the center
    angle_y = -np.arctan2(direction_vector[1], direction_vector[2])  # Rotation about Y-axis
    angle_x = np.arctan2(direction_vector[0], direction_vector[2])  # Rotation about X-axis

    # Apply rotation speed
    angle_y *= rotation_speed
    angle_x *= rotation_speed

    # Compute rotation matrices
    R_x = cv2.Rodrigues(np.array([1, 0, 0]) * angle_x)[0]
    R_y = cv2.Rodrigues(np.array([0, 1, 0]) * angle_y)[0]
    R = R_y @ R_x

    # Translation step along the X/Y/Z axes
    T = np.array([step_x, -step_y, -step_z])

    # Create extrinsic matrix
    extrinsic_matrix = np.eye(4)
    extrinsic_matrix[:3, :3] = R
    extrinsic_matrix[:3, 3] = T

    return extrinsic_matrix

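# VAE helpers: images in [0, 1] are mapped to [-1, 1], encoded to latents, and scaled by
# 0.18215 (the standard Stable Diffusion latent scaling factor); decode_latents() inverts
# both steps.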
@torch.no_grad()
def encode_imgs(imgs):
    imgs = 2 * imgs - 1
    posterior = pipe.vae.encode(imgs).latent_dist
    latents = posterior.mean * 0.18215
    return latents

@torch.no_grad()
def decode_latents(latents):
    latents = 1 / 0.18215 * latents
    imgs = pipe.vae.decode(latents).sample
    imgs = (imgs / 2 + 0.5).clamp(0, 1)
    return imgs

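# DDIM inversion: walk the scheduler's timesteps in reverse (from clean towards noisy) and
# re-apply the deterministic DDIM update backwards. With mu = sqrt(alpha_bar) and
# sigma = sqrt(1 - alpha_bar), each step treats the current latent as the previous (less
# noisy) sample, predicts
#     x0 = (x_prev - sigma_prev * eps) / mu_prev,
# and re-noises it to the current timestep,
#     x_t = mu * x0 + sigma * eps,
# so the returned latent sits at roughly the noise level selected by stop_t.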
@torch.no_grad()
def ddim_inversion(latent, cond, stop_t=1000, start_t=-1):
    timesteps = reversed(pipe.scheduler.timesteps)
    pipe.scheduler.set_timesteps(num_inference_steps)
    for i, t in enumerate(tqdm(timesteps)):
        if t >= stop_t:
            break
        if t <= start_t:
            continue
        cond_batch = cond.repeat(latent.shape[0], 1, 1)

        alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = (
            pipe.scheduler.alphas_cumprod[timesteps[i - 1]]
            if i > 0 else pipe.scheduler.final_alpha_cumprod
        )

        mu = alpha_prod_t ** 0.5
        mu_prev = alpha_prod_t_prev ** 0.5
        sigma = (1 - alpha_prod_t) ** 0.5
        sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

        eps = pipe.unet(latent, t, encoder_hidden_states=cond_batch).sample

        pred_x0 = (latent - sigma_prev * eps) / mu_prev
        latent = mu * pred_x0 + sigma * eps

    return latent

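# Classifier-free guidance text embeddings: the returned tensor stacks `batch_size` copies
# of the unconditional (negative-prompt) embedding followed by `batch_size` copies of the
# conditional embedding, the usual [uncond, cond] layout used for classifier-free guidance.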
@torch.no_grad()
def get_text_embeds(prompt, negative_prompt='', batch_size=1):
    text_input = pipe.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, return_tensors='pt')
    text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]

    uncond_input = pipe.tokenizer(negative_prompt, padding='max_length', max_length=77, truncation=True, return_tensors='pt')
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]

    # cat for final embeddings
    text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size).to(torch_dtype)
    return text_embeddings

def save_video(frames, fps=10, out_path='output/output.mp4'):
    video_dims = (512, 512)
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
    os.makedirs(os.path.dirname(out_path), exist_ok=True)  # create the output dir before opening the writer
    video = cv2.VideoWriter(out_path, fourcc, fps, video_dims)
    for frame in frames:
        video.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    video.release()

def draw_prompt(prompt):
    return prompt

def to_image(tensor):
    tensor = tensor.squeeze(0).permute(1, 2, 0)
    arr = tensor.detach().cpu().numpy()
    arr = (arr - arr.min()) / (arr.max() - arr.min())
    arr = arr * 255
    return arr.astype('uint8')

def add_points_to_image(image, points):
    image = gradio_utils.draw_handle_target_points(image, points, 5)
    return image


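# Click handler: for each of the `count` frames it (1) DDIM-inverts the current frame up to
# t1 and up to t3, (2) warps the t1 latent to the new camera pose with the depth-based
# low-pass warp (wrap_img_tensor_w_fft_ext), (3) continues the inversion from t1 to t2 and
# adds DDPM noise up to t3, and (4) denoises the pair of latents with KV injection and
# feature-correspondence guidance, keeping the second (warped) result as the next frame.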
def on_click(state, seed, count, prompt, neg_prompt, speed_r, speed_x, speed_y, speed_z, t1, t2, t3, lr, guidance_weight, attn, threshold, early_stop, evt: gr.SelectData):
    end_id = int(t1)
    start_id = int(t2)
    startstart_id = int(t3)
    timesteps = reversed(ddim_scheduler.timesteps)
    end_t = timesteps[end_id]
    start_t = timesteps[start_id]
    startstart_t = timesteps[startstart_id]
    attn = float(attn)
    cfg_norm = False
    cfg_decay = False
    guidance_loss_scale = float(guidance_weight)
    lr = float(lr)
    threshold = int(threshold)
    up_ft_indexes = 2
    early_stop = int(early_stop)
    generator = torch.Generator(device).manual_seed(int(seed))  # 19491001

    state['direction_offset'] = [int(evt.index[0]), int(evt.index[1])]
    cond = pipe._encode_prompt(prompt, device, 1, True, '')
    for _ in range(int(count)):
        image = state['img']
        img_tensor = torch.from_numpy(np.array(image) / 255.).to(device).to(torch_dtype).permute(2, 0, 1).unsqueeze(0)
        _, _, depth = pipe.midas_model(np.array(image))

        centered = is_centered(state['direction_offset'])
        if centered:
            extrinsic = compute_extrinsic_parameters(state['direction_offset'], depth, intrinsic, rotation_speed=float(0), step_z=float(speed_z), step_x=float(speed_x), step_y=float(speed_y))
            state['centered'] = centered
        else:
            extrinsic = compute_extrinsic_parameters(state['direction_offset'], depth, intrinsic, rotation_speed=float(speed_r), step_z=float(speed_z), step_x=float(speed_x), step_y=float(speed_y))

        this_latent = encode_imgs(img_tensor)
        this_ddim_inv_noise_end = ddim_inversion(this_latent, cond[1:], stop_t=end_t)
        this_ddim_inv_noise_start = ddim_inversion(this_latent, cond[1:], stop_t=startstart_t)

        wrapped_this_ddim_inv_noise_end = pipe.midas_model.wrap_img_tensor_w_fft_ext(
            this_ddim_inv_noise_end.to(torch_dtype),
            torch.from_numpy(depth).to(device).to(torch_dtype),
            intrinsic,
            extrinsic[:3, :3], extrinsic[:3, 3], threshold=threshold).to(torch_dtype)

        wrapped_this_ddim_inv_noise_start = ddim_inversion(wrapped_this_ddim_inv_noise_end, cond[1:], stop_t=start_t, start_t=end_t)
        wrapped_this_ddim_inv_noise_start = DDPM_forward(wrapped_this_ddim_inv_noise_start, t_start=start_t, delta_t=(startstart_id - start_id) * 20,
                                                         ddpm_scheduler=ddpm_scheduler, generator=generator)

        new_img = pipe.denoise_w_injection(
            prompt, generator=generator, num_inference_steps=num_inference_steps,
            latents=torch.cat([this_ddim_inv_noise_start, wrapped_this_ddim_inv_noise_start], dim=0), t_start=startstart_t,
            latent_mask=torch.ones_like(this_latent[0, 0, ...], device=device).unsqueeze(0),
            f=0, attn=attn, guidance_scale=7.5, negative_prompt=neg_prompt,
            guidance_loss_scale=guidance_loss_scale, early_stop=early_stop, up_ft_indexes=[up_ft_indexes],
            cfg_norm=cfg_norm, cfg_decay=cfg_decay, lr=lr,
            intrinsic=intrinsic, extrinsic=extrinsic, threshold=threshold, depth=depth,
        ).images[1]

        new_img = np.array(new_img).astype(np.uint8)
        state['img'] = new_img

        state['img_his'].append(new_img)
        depth = (depth - depth.min()) / (depth.max() - depth.min()) * 1.
        state['depth_his'].append(depth)

    return new_img, depth, state['img_his'], state

def is_centered(clicked_point, image_dimensions=(512, 512), threshold=5):
    image_center = [dim // 2 for dim in image_dimensions]
    return all(abs(clicked_point[i] - image_center[i]) <= threshold for i in range(2))


def gen_img(prompt, neg_prompt, state, seed):
    generator = torch.Generator(device).manual_seed(int(seed))  # 19491001
    img = pipe(
        prompt, generator=generator, num_inference_steps=num_inference_steps, negative_prompt=neg_prompt,
    ).images[0]
    img_array = np.array(img)
    _, _, depth = pipe.midas_model(img_array)
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 1.

    state['img_his'] = [img_array]
    state['depth_his'] = [depth]
    try:
        state['ori_img'] = img_array
        state['img'] = img_array
    except Exception:
        ipdb.set_trace()
    return img_array, depth, [img_array], state

def on_undo(state):
    if len(state['img_his']) > 1:
        del state['img_his'][-1]
        del state['depth_his'][-1]
    image = state['img_his'][-1]
    depth = state['depth_his'][-1]
    state['img'] = image
    return image, depth, state['img_his'], state

def on_reset(state):
    image = state['img_his'][0]
    depth = state['depth_his'][0]
    state['img'] = image
    state['img_his'] = [image]
    state['depth_his'] = [depth]
    return image, depth, state['img_his'], state

def get_prompt(text):
    return text

def on_save(state, video_name):
    save_video(state['img_his'], fps=5, out_path=f'output/{video_name}.mp4')

def on_seed(seed):
    return int(seed)

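# main() builds the Gradio Blocks UI, wires the callbacks above to the components, and then
# loads the DDIMBackward pipeline plus its DDIM/DDPM schedulers into module-level globals so
# the handlers can reach them.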
def main(args):
    with gr.Blocks() as demo:
        gr.Markdown(
            """
# DreamDrone

Official implementation of [DreamDrone](https://hyokong.github.io/publications/dreamdrone-page/).

**TL;DR:** Navigate dreamscapes with a ***click*** – your chosen point guides the drone's flight in a thrilling visual journey.

## Tutorial

1. Enter your prompt (and a negative prompt, if necessary) in the textbox, then click the `Generate first image` button.
2. Adjust the camera's moving speed in the `Direction` panel and set hyperparameters in the `Hyper params` panel.
3. Click on the generated image to make the camera fly towards the clicked direction.
4. The generated images will be displayed in the gallery at the bottom. You can view these images by clicking on them in the gallery or by using the left/right arrow buttons.

## Hints

- For convenience, you can set how many images are generated after each click on the image.
- Our system uses a right-handed coordinate system, with the Z-axis pointing into the image.
- The rotation speed determines how quickly the camera turns towards the clicked direction (rotation only, no translation). Increase it if you need faster camera pose changes.
- The Speed X/Y/Z values control the camera's translation along the X, Y, and Z axes. Adjust these parameters for different movement styles, similar to a camera arm.
- $t_1$ is the timestep at which the latent code is warped to the new camera pose.
- Noise is added from $t_1$ to $t_3$. Between $t_1$ and $t_2$, noise is sourced from a pretrained diffusion U-Net. From $t_2$ to $t_3$, random Gaussian noise is used.
- The `Learning rate` and `Feature correspondence guidance` control the feature-correspondence guidance weight during the denoising process (from timestep $t_3$ to $0$).
- The `KV injection` parameter adjusts the extent of key and value injection from the current frame to the next.

> If you encounter any problems, please open an issue. Also, don't forget to star the [Official GitHub Repo](https://github.com/HyoKong/DreamDrone).

***Without further ado, welcome to DreamDrone – enjoy piloting your virtual drone through imaginative landscapes!***

            """,
        )
        img = np.zeros((512, 512, 3)).astype(np.uint8)
        depth_img = np.zeros((512, 512, 3)).astype(np.uint8)
        intrinsic_matrix = np.array([[1000, 0, 512 / 2],
                                     [0, 1000, 512 / 2],
                                     [0, 0, 1]])  # Example intrinsic matrix
        extrinsic_matrix = np.array([[1.0, 0.0, 0.0, 0.0],
                                     [0.0, 1.0, 0.0, 0.0],
                                     [0.0, 0.0, 1.0, 0.0]],
                                    dtype=np.float32)
        direction_offset = (255, 255)
        state = gr.State({
            'ori_img': img,
            'img': None,
            'centered': False,
            'img_his': [],
            'depth_his': [],
            'intrinsic': intrinsic_matrix,
            'extrinsic': extrinsic_matrix,
            'direction_offset': direction_offset
        })

        with gr.Row():
            with gr.Column(scale=0.2):
                with gr.Accordion("Direction"):
                    speed_r = gr.Number(value=0.1, label='Rotation Speed', step=0.01, minimum=0, maximum=1)
                    speed_x = gr.Number(value=0, label='Speed X-axis', step=1, minimum=-10, maximum=20.0)
                    speed_y = gr.Number(value=0, label='Speed Y-axis', step=1, minimum=-10, maximum=20.0)
                    speed_z = gr.Number(value=5, label='Speed Z-axis', step=1, minimum=-10, maximum=20.0)
                with gr.Accordion('Hyper params'):
                    with gr.Row():
                        count = gr.Number(value=5, label='Num. of generated images', step=1, minimum=1, maximum=10, precision=0)
                        seed = gr.Number(value=19491000, label='Seed', precision=0)
                    t1 = gr.Slider(1, 49, 2, step=1, label='t1')
                    t2 = gr.Slider(1, 49, 12, step=1, label='t2')
                    t3 = gr.Slider(1, 49, 27, step=1, label='t3')
                    lr = gr.Slider(0, 500, 300, step=1, label='Learning rate')
                    guidance_weight = gr.Slider(0, 10, 0.1, step=0.1, label='Feature correspondence guidance')
                    attn = gr.Slider(0, 1, 0.5, step=0.1, label='KV injection')
                    threshold = gr.Slider(0, 31, 20, step=1, label='Threshold of low-pass filter')
                    early_stop = gr.Slider(0, 50, 48, step=1, label='Early stop timestep for feature-correspondence guidance')
                    video_name = gr.Textbox(
                        label="Saved video name", show_label=True, max_lines=1, placeholder='saved video name', value='output',
                    ).style()

            with gr.Column():
                with gr.Box():
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        text = gr.Textbox(
                            label="Enter your prompt", show_label=False, max_lines=1, placeholder='Enter your prompt', value='Backyards of Old Houses in Antwerp in the Snow, van Gogh',
                        ).style(
                            border=(True, False, True, True),
                            rounded=(True, False, False, True),
                            container=False,
                        )
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        with gr.Column(scale=0.8):
                            neg_text = gr.Textbox(
                                label="Enter your negative prompt", show_label=False, max_lines=1, value='', placeholder='Enter your negative prompt',
                            ).style(
                                border=(True, False, True, True),
                                rounded=(True, False, False, True),
                                container=False,
                            )
                        with gr.Column(scale=0.2):
                            gen_btn = gr.Button("Generate first image").style(
                                margin=False,
                                rounded=(False, True, True, False),
                            )

                with gr.Box():
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        with gr.Column():
                            with gr.Tab('Current view'):
                                image = gr.Image(img).style(height=600, width=600)
                        with gr.Column():
                            with gr.Tab('Depth'):
                                depth_image = gr.Image(depth_img).style(height=600, width=600)
                with gr.Row():
                    with gr.Column(min_width=100):
                        reset_btn = gr.Button('Clear All')
                    with gr.Column(min_width=100):
                        undo_btn = gr.Button('Undo Last')
                    with gr.Column(min_width=100):
                        save_btn = gr.Button('Save Video')
                with gr.Row():
                    with gr.Tab('Generated image gallery'):
                        gallery = gr.Gallery(
                            label='Generated images', show_label=False, elem_id='gallery', preview=True, rows=1, height=368,
                        ).style()

        image.select(on_click, [state, seed, count, text, neg_text, speed_r, speed_x, speed_y, speed_z, t1, t2, t3, lr, guidance_weight, attn, threshold, early_stop], [image, depth_image, gallery, state])
        text.submit(get_prompt, inputs=[text], outputs=[text])
        neg_text.submit(get_prompt, inputs=[neg_text], outputs=[neg_text])
        gen_btn.click(gen_img, inputs=[text, neg_text, state, seed], outputs=[image, depth_image, gallery, state])
        reset_btn.click(on_reset, inputs=[state], outputs=[image, depth_image, gallery, state])
        undo_btn.click(on_undo, inputs=[state], outputs=[image, depth_image, gallery, state])
        save_btn.click(on_save, inputs=[state, video_name], outputs=[])

    global num_inference_steps
    global pipe
    global intrinsic
    global ddim_scheduler
    global ddpm_scheduler
    global device
    global model_id
    global torch_dtype

    num_inference_steps = 50

    device = args.device
    model_id = args.model_id
    ddim_scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    ddpm_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    torch_dtype = torch.float16 if 'cuda' in str(device) else torch.float32

    pipe = DDIMBackward.from_pretrained(
        model_id, scheduler=ddim_scheduler, torch_dtype=torch_dtype,
        cache_dir='.', device=str(device), model_id=model_id, depth_model=args.depth_model,
    ).to(str(device))

    if 'cuda' in str(device):
        pipe.enable_attention_slicing()
        pipe.enable_xformers_memory_efficient_attention()

    intrinsic = np.array([[1000, 0, 256],
                          [0, 1000., 256],
                          [0, 0, 1]])  # Example intrinsic matrix
    return demo


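# Example launch (defaults shown; all flags are optional):
#   python app.py --device cuda --model_id stabilityai/stable-diffusion-2-1-base --depth_model dpt_beit_large_512
# Add --share for a public Gradio link, and --ip / --port (-p) to bind a specific address.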
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--model_id', default='stabilityai/stable-diffusion-2-1-base')
    parser.add_argument('--depth_model', default='dpt_beit_large_512', choices=['dpt_beit_large_512', 'dpt_swin2_large_384'])
    parser.add_argument('--share', action='store_true')
    parser.add_argument('-p', '--port', type=int, default=None)
    parser.add_argument('--ip', default=None)
    args = parser.parse_args()
    demo = main(args)
    print('Successfully loaded, starting gradio demo')
    demo.queue(concurrency_count=1, max_size=20).launch(share=args.share, server_name=args.ip, server_port=args.port)