Friedrich-M committed on
Commit
b2d9087
·
1 Parent(s): 038a20f

update 11.22

Files changed (2)
  1. app.py +730 -0
  2. requirements.txt +20 -0
app.py ADDED
@@ -0,0 +1,730 @@
+ import spaces
+ import gradio as gr
+ import torch
+ import cv2
+ import numpy as np
+ import mediapipe as mp
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionControlNetInpaintPipeline
+ from transformers import AutoTokenizer
+ import base64
+ import requests
+ import json
+ from rembg import remove
+ from scipy import ndimage
+ from moviepy.editor import ImageSequenceClip
+ from tqdm import tqdm
+ import os
+ import shutil
+ import time
+ from huggingface_hub import snapshot_download
+ import subprocess
+ import sys
+
+
+ @spaces.GPU(duration=120)
+ def download_liveportrait():
+     """
+     Clone the LivePortrait repository and prepare its dependencies.
+     """
+     liveportrait_path = "./LivePortrait"
+     try:
+         if not os.path.exists(liveportrait_path):
+             print("Cloning LivePortrait repository...")
+             os.system(f"git clone https://github.com/KwaiVGI/LivePortrait.git {liveportrait_path}")
+
+         # Install dependencies
+         os.chdir(liveportrait_path)
+         print("Installing LivePortrait dependencies...")
+         os.system("pip install -r requirements.txt")
+
+         # Build the MultiScaleDeformableAttention module
+         dependency_path = "src/utils/dependencies/XPose/models/UniPose/ops"
+         os.chdir(dependency_path)
+         print("Building MultiScaleDeformableAttention...")
+         os.system("python setup.py build")
+         os.system("python setup.py install")
+
+         # Make sure the built module is importable (we are already inside
+         # dependency_path here, so the current directory is the module path)
+         module_path = os.getcwd()
+         if module_path not in sys.path:
+             sys.path.append(module_path)
+
+         # Return to the LivePortrait directory
+         os.chdir("../../../../../../../")
+         print("LivePortrait setup completed")
+     except Exception as e:
+         print("Failed to initialize LivePortrait:", e)
+         raise
+
+ download_liveportrait()
+
+ @spaces.GPU(duration=120)
+ def download_huggingface_resources():
+     """
+     Download additional necessary resources from Hugging Face using the CLI.
+     """
+     try:
+         local_dir = "./pretrained_weights"
+         os.makedirs(local_dir, exist_ok=True)
+
+         # Use the Hugging Face CLI for downloading
+         cmd = [
+             "huggingface-cli", "download",
+             "KwaiVGI/LivePortrait",
+             "--local-dir", local_dir,
+             "--exclude", "*.git*", "README.md", "docs",
+         ]
+         print("Executing command:", " ".join(cmd))
+         subprocess.run(cmd, check=True)
+
+         print("Resources successfully downloaded to:", local_dir)
+     except subprocess.CalledProcessError as e:
+         print("Error during Hugging Face CLI download:", e)
+         raise
+     except Exception as e:
+         print("General error in downloading resources:", e)
+         raise
+
+ download_huggingface_resources()
+
+
+ @spaces.GPU(duration=120)
+ def get_project_root():
+     """Get the root directory of the current project."""
+     return os.path.abspath(os.path.dirname(__file__))
+
+ # Ensure the working directory is the project root
+ os.chdir(get_project_root())
+
+ # Initialize the necessary models and components
+ mp_pose = mp.solutions.pose
+ mp_drawing = mp.solutions.drawing_utils
+
+ # Load the OpenPose ControlNet model
+ controlnet = ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-openpose', torch_dtype=torch.float16)
+
+ # Load Stable Diffusion model with ControlNet
+ pipe_controlnet = StableDiffusionControlNetPipeline.from_pretrained(
+     'runwayml/stable-diffusion-v1-5',
+     controlnet=controlnet,
+     torch_dtype=torch.float16
+ )
+
+ # Load the ControlNet inpainting pipeline
+ pipe_inpaint_controlnet = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+     "runwayml/stable-diffusion-inpainting",
+     controlnet=controlnet,
+     torch_dtype=torch.float16
+ )
+
+ # Move to GPU if available
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ pipe_controlnet.to(device)
+ pipe_controlnet.enable_attention_slicing()
+ pipe_inpaint_controlnet.to(device)
+ pipe_inpaint_controlnet.enable_attention_slicing()
+
+
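+ # Both pipelines share one ControlNet and run in float16 with attention
+ # slicing, which trades a little speed for a much smaller VRAM footprint so
+ # that both can stay resident on a single GPU at the same time.
+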
+ @spaces.GPU(duration=120)
+ def resize_to_multiple_of_64(width, height):
+     # Stable Diffusion expects dimensions that are multiples of 64, so round down
+     return (width // 64) * 64, (height // 64) * 64
+
+
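+ # Example: a 1024x575 request maps to 1024x512. Rounding down rather than up
+ # keeps the generation canvas no larger than the requested size.
+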
+ @spaces.GPU(duration=120)
+ def expand_mask(mask, kernel_size):
+     mask_array = np.array(mask)
+     structuring_element = np.ones((kernel_size, kernel_size), dtype=np.uint8)
+     expanded_mask_array = ndimage.binary_dilation(
+         mask_array, structure=structuring_element
+     ).astype(np.uint8) * 255
+     return Image.fromarray(expanded_mask_array)
+
+
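+ # Dilating the person mask before inpainting gives the model a margin around
+ # the subject to repaint, avoiding a halo at the old silhouette; a k x k
+ # kernel grows the mask by roughly k/2 pixels on each side.
+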
+ @spaces.GPU(duration=120)
+ def crop_face_to_square(image_rgb, padding_ratio=0.2):
+     """
+     Detect the face in the input image and crop an enlarged square region around it.
+     """
+     face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
+     gray_image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
+     faces = face_cascade.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
+
+     if len(faces) == 0:
+         print("No face detected.")
+         return None
+
+     # Build a padded square window centered on the first detected face
+     x, y, w, h = faces[0]
+     center_x, center_y = x + w // 2, y + h // 2
+     side_length = max(w, h)
+     padded_side_length = int(side_length * (1 + padding_ratio))
+     half_side = padded_side_length // 2
+
+     # Clamp the window to the image bounds
+     top_left_x = max(center_x - half_side, 0)
+     top_left_y = max(center_y - half_side, 0)
+     bottom_right_x = min(center_x + half_side, image_rgb.shape[1])
+     bottom_right_y = min(center_y + half_side, image_rgb.shape[0])
+
+     cropped_image = image_rgb[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
+     resized_image = cv2.resize(cropped_image, (768, 768), interpolation=cv2.INTER_AREA)
+
+     return resized_image
+
+
+ @spaces.GPU(duration=120)
+ def spirit_animal_baseline(image_path, num_images=4):
+
+     image = cv2.imread(image_path)
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     image_rgb = crop_face_to_square(image_rgb)
+     if image_rgb is None:
+         return "No face detected.", []
+
+     original_height, original_width, _ = image_rgb.shape
+     aspect_ratio = original_width / original_height
+
+     # Fit the longer side to 768 px while preserving the aspect ratio
+     if aspect_ratio > 1:
+         gen_width = 768
+         gen_height = int(gen_width / aspect_ratio)
+     else:
+         gen_height = 768
+         gen_width = int(gen_height * aspect_ratio)
+
+     gen_width, gen_height = resize_to_multiple_of_64(gen_width, gen_height)
+
+     with mp_pose.Pose(static_image_mode=True) as pose:
+         results = pose.process(image_rgb)
+
+     if results.pose_landmarks:
+         annotated_image = image_rgb.copy()
+         mp_drawing.draw_landmarks(
+             annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
+         )
+     else:
+         print("No pose detected.")
+         return "No pose detected.", []
+
+     # Draw the detected skeleton on a black canvas for ControlNet conditioning
+     pose_image = np.zeros_like(image_rgb)
+     for connection in mp_pose.POSE_CONNECTIONS:
+         start_idx, end_idx = connection
+         start, end = results.pose_landmarks.landmark[start_idx], results.pose_landmarks.landmark[end_idx]
+         if start.visibility > 0.5 and end.visibility > 0.5:
+             x1, y1 = int(start.x * pose_image.shape[1]), int(start.y * pose_image.shape[0])
+             x2, y2 = int(end.x * pose_image.shape[1]), int(end.y * pose_image.shape[0])
+             cv2.line(pose_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
+
+     pose_pil = Image.fromarray(cv2.resize(pose_image, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4))
+
+     # Convert back to BGR before JPEG encoding so the uploaded image keeps correct colors
+     base64_image = base64.b64encode(cv2.imencode('.jpg', cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))[1]).decode()
+     api_key = os.environ.get("OPENAI_API_KEY")  # read the secret from the environment; never hardcode it
+     headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+     payload = {
+         "model": "gpt-4o-mini",
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": "Based on the provided image, think of one spirit animal that is right for the person, and answer in the following format: An ultra-realistic, highly detailed photograph of a single {animal} with facial features characterized by {description}, standing upright in a human-like pose, looking directly at the camera, against a solid, neutral background. Generate one sentence without any other responses or numbering."},
+                     {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+                 ]
+             }
+         ],
+         "max_tokens": 100
+     }
+
+     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+     response_json = response.json()
+     prompt = response_json['choices'][0]['message']['content'] if 'choices' in response_json else "A majestic animal"
+
+     generated_images = []
+     with torch.no_grad():
+         with torch.autocast(device_type=device.type):
+             for _ in range(num_images):
+                 images = pipe_controlnet(
+                     prompt=prompt,
+                     negative_prompt="multiple heads, extra limbs, duplicate faces, mutated anatomy, disfigured, blurry",
+                     num_inference_steps=20,
+                     image=pose_pil,
+                     guidance_scale=5,
+                     width=gen_width,
+                     height=gen_height,
+                 ).images
+                 generated_images.append(images[0])
+
+     return prompt, generated_images
+
+
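+ # Typical call (hypothetical path): spirit_animal_baseline("selfie.jpg", 2)
+ # returns the GPT-written prompt plus two pose-matched generations.
+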
+ @spaces.GPU(duration=120)
+ def spirit_animal_with_background(image_path, num_images=4):
+
+     image = cv2.imread(image_path)
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     # image_rgb = crop_face_to_square(image_rgb)
+
+     original_height, original_width, _ = image_rgb.shape
+     aspect_ratio = original_width / original_height
+
+     if aspect_ratio > 1:
+         gen_width = 768
+         gen_height = int(gen_width / aspect_ratio)
+     else:
+         gen_height = 768
+         gen_width = int(gen_height * aspect_ratio)
+
+     gen_width, gen_height = resize_to_multiple_of_64(gen_width, gen_height)
+
+     with mp_pose.Pose(static_image_mode=True) as pose:
+         results = pose.process(image_rgb)
+
+     if results.pose_landmarks:
+         annotated_image = image_rgb.copy()
+         mp_drawing.draw_landmarks(
+             annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
+         )
+     else:
+         print("No pose detected.")
+         return "No pose detected.", []
+
+     pose_image = np.zeros_like(image_rgb)
+     for connection in mp_pose.POSE_CONNECTIONS:
+         start_idx, end_idx = connection
+         start, end = results.pose_landmarks.landmark[start_idx], results.pose_landmarks.landmark[end_idx]
+         if start.visibility > 0.5 and end.visibility > 0.5:
+             x1, y1 = int(start.x * pose_image.shape[1]), int(start.y * pose_image.shape[0])
+             x2, y2 = int(end.x * pose_image.shape[1]), int(end.y * pose_image.shape[0])
+             cv2.line(pose_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
+
+     pose_pil = Image.fromarray(cv2.resize(pose_image, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4))
+
+     base64_image = base64.b64encode(cv2.imencode('.jpg', cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))[1]).decode()
+     api_key = os.environ.get("OPENAI_API_KEY")  # read the secret from the environment; never hardcode it
+     headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+     payload = {
+         "model": "gpt-4o-mini",
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": "Based on the provided image, think of one spirit animal that is right for the person, and answer in the following format: An ultra-realistic, highly detailed photograph of a single {animal} with facial features characterized by {description}, standing upright in a human-like pose, looking directly at the camera, against a solid, neutral background. Generate one sentence without any other responses or numbering."},
+                     {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+                 ]
+             }
+         ],
+         "max_tokens": 100
+     }
+
+     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+     response_json = response.json()
+     prompt = response_json['choices'][0]['message']['content'] if 'choices' in response_json else "A majestic animal"
+
+     # Segment the person with rembg, then dilate the alpha mask for inpainting
+     mask_image = remove(Image.fromarray(image_rgb))
+     initial_mask = mask_image.split()[-1].convert('L')
+
+     kernel_size = min(gen_width, gen_height) // 15
+     expanded_mask = expand_mask(initial_mask, kernel_size)
+
+     generated_images = []
+     with torch.no_grad():
+         with torch.autocast(device_type=device.type):
+             for _ in range(num_images):
+                 images = pipe_inpaint_controlnet(
+                     prompt=prompt,
+                     negative_prompt="multiple heads, extra limbs, duplicate faces, mutated anatomy, disfigured, blurry",
+                     num_inference_steps=20,
+                     image=Image.fromarray(image_rgb),
+                     mask_image=expanded_mask,
+                     control_image=pose_pil,
+                     width=gen_width,
+                     height=gen_height,
+                     guidance_scale=5,
+                 ).images
+                 generated_images.append(images[0])
+
+     return prompt, generated_images
+
+
+ @spaces.GPU(duration=120)
+ def generate_multiple_animals(image_path, keep_background=True, num_images=4):
+
+     image = cv2.imread(image_path)
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     image_rgb = crop_face_to_square(image_rgb)
+     if image_rgb is None:
+         return "No face detected.", []
+
+     original_image = Image.fromarray(image_rgb)
+     original_width, original_height = original_image.size
+
+     aspect_ratio = original_width / original_height
+     if aspect_ratio > 1:
+         gen_width = 768
+         gen_height = int(gen_width / aspect_ratio)
+     else:
+         gen_height = 768
+         gen_width = int(gen_height * aspect_ratio)
+
+     gen_width, gen_height = resize_to_multiple_of_64(gen_width, gen_height)
+
+     base64_image = base64.b64encode(cv2.imencode('.jpg', cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))[1]).decode()
+     api_key = os.environ.get("OPENAI_API_KEY")  # read the secret from the environment; never hardcode it
+     headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+     payload = {
+         "model": "gpt-4o-mini",
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "text",
+                         "text": "Based on the provided image, think of " + str(num_images) + " different spirit animals that are right for the person, and answer in the following format for each: An ultra-realistic, highly detailed photograph of a {animal} with facial features characterized by {description}, standing upright in a human-like pose, looking directly at the camera, against a solid, neutral background. Generate these sentences without any other responses or numbering. For the animal choose between owl, bear, fox, koala, lion, dog"
+                     },
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
+                     }
+                 ]
+             }
+         ],
+         "max_tokens": 500
+     }
+
+     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+     response_json = response.json()
+
+     if 'choices' in response_json and len(response_json['choices']) > 0:
+         # Each generated sentence ends with a period, so splitting on '.' yields one prompt per animal
+         content = response_json['choices'][0]['message']['content']
+         prompts = [prompt.strip() for prompt in content.strip().split('.') if prompt.strip()]
+     else:
+         print("Prompt generation failed.")
+         return "Prompt generation failed.", []
+
+     negative_prompt = (
+         "multiple heads, extra limbs, duplicate faces, mutated anatomy, disfigured, "
+         "blurry, deformed, text, watermark, logo, low resolution"
+     )
+     formatted_prompts = "\n".join(f"{i+1}. {prompt}" for i, prompt in enumerate(prompts))
+
+     with mp_pose.Pose(static_image_mode=True) as pose:
+         results = pose.process(image_rgb)
+
+     if results.pose_landmarks:
+         annotated_image = image_rgb.copy()
+         mp_drawing.draw_landmarks(
+             annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
+         )
+     else:
+         print("No pose detected.")
+         return "No pose detected.", []
+
+     pose_image = np.zeros_like(image_rgb)
+     for connection in mp_pose.POSE_CONNECTIONS:
+         start_idx, end_idx = connection
+         start, end = results.pose_landmarks.landmark[start_idx], results.pose_landmarks.landmark[end_idx]
+         if start.visibility > 0.5 and end.visibility > 0.5:
+             x1, y1 = int(start.x * pose_image.shape[1]), int(start.y * pose_image.shape[0])
+             x2, y2 = int(end.x * pose_image.shape[1]), int(end.y * pose_image.shape[0])
+             cv2.line(pose_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
+
+     pose_pil = Image.fromarray(cv2.resize(pose_image, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4))
+
+     if keep_background:
+         mask_image = remove(original_image)
+         initial_mask = mask_image.split()[-1].convert('L')
+         expanded_mask = expand_mask(initial_mask, kernel_size=min(gen_width, gen_height) // 15)
+     else:
+         expanded_mask = None
+
+     generated_images = []
+
+     if keep_background:
+         with torch.no_grad():
+             with torch.amp.autocast("cuda"):
+                 for prompt in prompts:
+                     images = pipe_inpaint_controlnet(
+                         prompt=prompt,
+                         negative_prompt=negative_prompt,
+                         num_inference_steps=20,
+                         image=Image.fromarray(image_rgb),
+                         mask_image=expanded_mask,
+                         control_image=pose_pil,
+                         width=gen_width,
+                         height=gen_height,
+                         guidance_scale=5,
+                     ).images
+                     generated_images.append(images[0])
+     else:
+         with torch.no_grad():
+             with torch.amp.autocast("cuda"):
+                 for prompt in prompts:
+                     images = pipe_controlnet(
+                         prompt=prompt,
+                         negative_prompt=negative_prompt,
+                         num_inference_steps=20,
+                         image=pose_pil,
+                         guidance_scale=5,
+                         width=gen_width,
+                         height=gen_height,
+                     ).images
+                     generated_images.append(images[0])
+
+     return formatted_prompts, generated_images
+
+
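+ # Caveat: splitting the reply on '.' assumes the model returns exactly one
+ # sentence per animal with no other periods, which is why the prompt above
+ # pins the output format so tightly.
+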
+ @spaces.GPU(duration=120)
+ def wait_for_file(file_path, timeout=500):
+     """
+     Wait for a file to be created, with a specified timeout.
+     Args:
+         file_path (str): The path of the file to wait for.
+         timeout (int): Maximum time to wait in seconds.
+     Returns:
+         bool: True if the file is created, False if the timeout occurs.
+     """
+     start_time = time.time()
+     while not os.path.exists(file_path):
+         if time.time() - start_time > timeout:
+             return False
+         time.sleep(0.5)  # Poll every 0.5 seconds
+     return True
+
+
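+ # Typical call (hypothetical path): wait_for_file("./animations/out.mp4", timeout=120)
+ # blocks until the file lands on disk or two minutes elapse, whichever comes first.
+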
+ @spaces.GPU(duration=120)
+ def generate_spirit_animal_video(driving_video_path):
+     try:
+         # Step 1: Extract the first frame
+         cap = cv2.VideoCapture(driving_video_path)
+         if not cap.isOpened():
+             print("Error: Unable to open video.")
+             return None
+
+         ret, frame = cap.read()
+         cap.release()
+         if not ret:
+             print("Error: Unable to read the first frame.")
+             return None
+
+         # Save the first frame
+         first_frame_path = "./first_frame.jpg"
+         cv2.imwrite(first_frame_path, frame)
+         print(f"First frame saved to: {first_frame_path}")
+
+         # Step 2: Generate the spirit animal image
+         _, input_image = generate_multiple_animals(first_frame_path, True, 1)
+         if input_image is None or not input_image:
+             print("Error: Spirit animal generation failed.")
+             return None
+
+         spirit_animal_path = "./animal.jpeg"
+         cv2.imwrite(spirit_animal_path, cv2.cvtColor(np.array(input_image[0]), cv2.COLOR_RGB2BGR))
+         print(f"Spirit animal image saved to: {spirit_animal_path}")
+
+         # Step 3: Run LivePortrait inference; the output path follows its
+         # "<source>--<driving>" naming scheme
+         output_path = "./animations/animal--uploaded_video_compressed.mp4"
+         script_path = os.path.abspath("./LivePortrait/inference_animals.py")
+
+         if not os.path.exists(script_path):
+             print(f"Error: Inference script not found at {script_path}.")
+             return None
+
+         command = f"python {script_path} -s {spirit_animal_path} -d {driving_video_path} --driving_multiplier 1.75 --no_flag_stitching"
+         print(f"Running command: {command}")
+         result = os.system(command)
+
+         if result != 0:
+             print(f"Error: Command failed with exit code {result}.")
+             return None
+
+         # Verify that the output file exists
+         if not os.path.exists(output_path):
+             print(f"Error: Expected output video not found at {output_path}.")
+             return None
+
+         print(f"Output video generated at: {output_path}")
+         return output_path
+     except Exception as e:
+         print(f"Error occurred: {e}")
+         return None
+
+
+ @spaces.GPU(duration=120)
+ def generate_spirit_animal(image, animal_type, background):
+     # Dispatch to the right generator based on the UI selections
+     if animal_type == "Single Animal":
+         if background == "Preserve Background":
+             prompt, generated_images = spirit_animal_with_background(image)
+         else:
+             prompt, generated_images = spirit_animal_baseline(image)
+     elif animal_type == "Multiple Animals":
+         if background == "Preserve Background":
+             prompt, generated_images = generate_multiple_animals(image, keep_background=True)
+         else:
+             prompt, generated_images = generate_multiple_animals(image, keep_background=False)
+     return prompt, generated_images
+
+
+ @spaces.GPU(duration=120)
+ def compress_video(input_path, output_path, target_size_mb):
+     target_size_bytes = target_size_mb * 1024 * 1024
+     temp_output = "./temp_compressed.mp4"
+
+     cap = cv2.VideoCapture(input_path)
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # use the mp4v codec
+     fps = int(cap.get(cv2.CAP_PROP_FPS))
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     # First pass: rewrite the stream frame by frame, counting frames as we go
+     writer = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
+     frame_count = 0
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         writer.write(frame)
+         frame_count += 1
+
+     cap.release()
+     writer.release()
+
+     current_size = os.path.getsize(temp_output)
+     if current_size > target_size_bytes:
+         # Second pass: re-encode at a bitrate that fits the target size
+         # (total target bits spread over the clip's duration in seconds)
+         duration = frame_count / fps if fps else 1.0
+         bitrate = int(target_size_bytes * 8 / max(duration, 1e-6))
+         os.system(f"ffmpeg -i {temp_output} -b:v {bitrate} -y {output_path}")
+         os.remove(temp_output)
+     else:
+         shutil.move(temp_output, output_path)
+
+
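+ # Worked example: a 1 MB target for a 10 s clip allows roughly
+ # 1 * 1024 * 1024 * 8 / 10 ≈ 839 kbit/s of video bitrate.
+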
+ @spaces.GPU(duration=120)
+ def process_video(video_file):
+     # LivePortrait setup and weight downloads happen once at import time (see
+     # download_liveportrait / download_huggingface_resources above)
+
+     compressed_path = "./uploaded_video_compressed.mp4"
+     compress_video(video_file, compressed_path, target_size_mb=1)
+     print(f"Compressed and moved video to: {compressed_path}")
+
+     output_video_path = "./animations/animal--uploaded_video_compressed.mp4"
+
+     generate_spirit_animal_video(compressed_path)
+
+     # Wait until the output video is generated
+     timeout = 60000  # generous timeout, in seconds
+     if not wait_for_file(output_video_path, timeout=timeout):
+         print("Timeout occurred while waiting for video generation.")
+         return gr.update(value=None, visible=False)  # hide the output on failure
+
+     # Return the generated video path
+     print(f"Output video is ready: {output_video_path}")
+     return gr.update(value=output_video_path, visible=True)  # show the video
+
+
+ # Custom CSS styling for the interface
+ css = """
+ #title-container {
+     font-family: 'Arial', sans-serif;
+     color: #4a4a4a;
+     text-align: center;
+     margin-bottom: 20px;
+ }
+ #title-container h1 {
+     font-size: 2.5em;
+     font-weight: bold;
+     color: #ff9900;
+ }
+ #title-container h2 {
+     font-size: 1.2em;
+     color: #6c757d;
+ }
+ #intro-text {
+     font-size: 1em;
+     color: #6c757d;
+     margin: 50px;
+     text-align: center;
+     font-style: italic;
+ }
+ #prompt-output {
+     font-family: 'Courier New', monospace;
+     color: #5a5a5a;
+     font-size: 1.1em;
+     padding: 10px;
+     background-color: #f9f9f9;
+     border: 1px solid #ddd;
+     border-radius: 5px;
+     margin-top: 10px;
+ }
+ """
+
+ # Title and description
+ title_html = """
+ <div id="title-container">
+     <h1>Spirit Animal Generator</h1>
+     <h2>Create your unique spirit animal with AI-assisted image generation.</h2>
+ </div>
+ """
+
+ description_text = """
+ ### Project Overview
+ Welcome to the Spirit Animal Generator! This tool leverages advanced AI technologies to create unique visualizations of spirit animals from both videos and images.
+
+ #### Key Features:
+ 1. **Video Transformation**: Upload a driving video to generate a creative spirit animal animation.
+ 2. **Image Creation**: Upload an image and customize the spirit animal type and background options.
+ 3. **AI-Powered Prompting**: OpenAI's GPT generates descriptive prompts for each input.
+ 4. **High-Quality Outputs**: Generated with Stable Diffusion and ControlNet for stunning visuals.
+
+ ---
+
+ ### How It Works:
+ 1. **Upload Your Media**:
+    - Videos: Ensure the file is in MP4 format.
+    - Images: Use clear, high-resolution photos for better results.
+ 2. **Customize Options**:
+    - For images, select the type of animal and background settings.
+ 3. **View Your Results**:
+    - Videos will be transformed into animations.
+    - Images will produce customized visual art along with a generated prompt.
+
+ Discover your spirit animal and let your imagination run wild!
+
+ ---
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML(title_html)
+     gr.Markdown(description_text)
+
+     with gr.Tabs():
+         with gr.Tab("Generate Spirit Animal Image"):
+             gr.Markdown("Upload an image to generate a spirit animal.")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     image_input = gr.Image(type="filepath", label="Upload an image")
+                     animal_type = gr.Radio(choices=["Single Animal", "Multiple Animals"], label="Animal Type", value="Single Animal")
+                     background_option = gr.Radio(choices=["Preserve Background", "Don't Preserve Background"], label="Background Option", value="Preserve Background")
+                     generate_image_button = gr.Button("Generate Image")
+                 with gr.Column(scale=1):
+                     generated_prompt = gr.Textbox(label="Generated Prompt", elem_id="prompt-output")
+                     generated_gallery = gr.Gallery(label="Generated Images")
+
+             generate_image_button.click(
+                 fn=generate_spirit_animal,
+                 inputs=[image_input, animal_type, background_option],
+                 outputs=[generated_prompt, generated_gallery],
+             )
+
+         with gr.Tab("Generate Spirit Animal Video"):
+             gr.Markdown("Upload a driving video to generate a spirit animal video.")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     video_input = gr.Video(label="Upload a driving video (MP4 format)")
+                     generate_video_button = gr.Button("Generate Video")
+                 with gr.Column(scale=1):
+                     video_output = gr.Video(label="Generated Spirit Animal Video")
+
+             generate_video_button.click(
+                 fn=process_video,
+                 inputs=video_input,
+                 outputs=video_output,
+             )
+
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ torch==2.1.2
+ torchvision==0.16.2
+ torchaudio==2.1.2
+ moviepy==1.0.3
+ imageio[ffmpeg]
+ pillow==10.4.0
+ tyro==0.8.5
+ onnxruntime-gpu==1.18.1
+ onnx==1.16.1
+ gradio==4.37.1
+ colorama
+ ffmpeg-python==0.2.0
+ mediapipe
+ rembg
+ huggingface_hub[cli]
+ opencv-python
+ matplotlib
+ diffusers
+ transformers
+ accelerate
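+ # direct imports of app.py not pinned above (versions assumed compatible)
+ requests
+ scipy
+ tqdm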