JoPmt commited on
Commit
0c4a0e4
·
verified ·
1 Parent(s): e9503d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -63
app.py CHANGED
@@ -117,60 +117,6 @@ os.makedirs("./gradio_tmp", exist_ok=True)
117
  upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
118
  frame_interpolation_model = load_rife_model("model_rife")
119
 
120
- @spaces.GPU(duration=65)
121
- def infer(
122
- prompt: str,
123
- image_input: str,
124
- num_inference_steps: int,
125
- guidance_scale: float,
126
- seed: int = 42,
127
- progress=gr.Progress(track_tqdm=True),
128
- ):
129
- if seed == -1:
130
- seed = random.randint(0, 2**8 - 1)
131
-
132
- id_image = np.array(ImageOps.exif_transpose(Image.fromarray(image_input)).convert("RGB"))
133
- id_image = resize_numpy_image_long(id_image, 1024)
134
- id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper, face_clip_model, handler_ante,
135
- eva_transform_mean, eva_transform_std,
136
- face_main_model, device, dtype, id_image,
137
- original_id_image=id_image, is_align_face=True,
138
- cal_uncond=False)
139
-
140
- if is_kps:
141
- kps_cond = face_kps
142
- else:
143
- kps_cond = None
144
-
145
- tensor = align_crop_face_image.cpu().detach()
146
- tensor = tensor.squeeze()
147
- tensor = tensor.permute(1, 2, 0)
148
- tensor = tensor.numpy() * 255
149
- tensor = tensor.astype(np.uint8)
150
- image = ImageOps.exif_transpose(Image.fromarray(tensor))
151
-
152
- prompt = prompt.strip('"')
153
-
154
- generator = torch.Generator(device).manual_seed(seed) if seed else None
155
-
156
- video_pt = pipe(
157
- prompt=prompt,
158
- image=image,
159
- num_videos_per_prompt=1,
160
- num_inference_steps=num_inference_steps,
161
- num_frames=49,
162
- use_dynamic_cfg=False,
163
- guidance_scale=guidance_scale,
164
- generator=generator,
165
- id_vit_hidden=id_vit_hidden,
166
- id_cond=id_cond,
167
- kps_cond=kps_cond,
168
- output_type="pt",
169
- ).frames
170
-
171
- ##free_memory()
172
- return video_pt, seed
173
-
174
 
175
  def convert_to_gif(video_path):
176
  clip = VideoFileClip(video_path)
@@ -196,7 +142,7 @@ def delete_old_files():
196
 
197
 
198
  ##threading.Thread(target=delete_old_files, daemon=True).start()
199
- @spaces.GPU
200
  def generate(
201
  prompt,
202
  image_input,
@@ -205,14 +151,40 @@ def generate(
205
  rife_status,
206
  progress=gr.Progress(track_tqdm=True)
207
  ):
208
- latents, seed = infer(
209
- prompt,
210
- image_input,
211
- num_inference_steps=4,
212
- guidance_scale=7.0,
213
- seed=seed_value,
214
- progress=progress,
215
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  if scale_status:
217
  latents = upscale_batch_and_concatenate(upscale_model, latents, device)
218
  if rife_status:
 
117
  upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
118
  frame_interpolation_model = load_rife_model("model_rife")
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  def convert_to_gif(video_path):
122
  clip = VideoFileClip(video_path)
 
142
 
143
 
144
  ##threading.Thread(target=delete_old_files, daemon=True).start()
145
+ @spaces.GPU(duration=65)
146
  def generate(
147
  prompt,
148
  image_input,
 
151
  rife_status,
152
  progress=gr.Progress(track_tqdm=True)
153
  ):
154
+ def infer(prompt: str,image_input: str,num_inference_steps: int,guidance_scale: float,seed: int = 42,progress=gr.Progress(track_tqdm=True),):
155
+ if seed == -1:
156
+ seed = random.randint(0, 2**8 - 1)
157
+
158
+ id_image = np.array(ImageOps.exif_transpose(Image.fromarray(image_input)).convert("RGB"))
159
+ id_image = resize_numpy_image_long(id_image, 1024)
160
+ id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper, face_clip_model, handler_ante,
161
+ eva_transform_mean, eva_transform_std,
162
+ face_main_model, device, dtype, id_image,
163
+ original_id_image=id_image, is_align_face=True,
164
+ cal_uncond=False)
165
+
166
+ if is_kps:
167
+ kps_cond = face_kps
168
+ else:
169
+ kps_cond = None
170
+
171
+ tensor = align_crop_face_image.cpu().detach()
172
+ tensor = tensor.squeeze()
173
+ tensor = tensor.permute(1, 2, 0)
174
+ tensor = tensor.numpy() * 255
175
+ tensor = tensor.astype(np.uint8)
176
+ image = ImageOps.exif_transpose(Image.fromarray(tensor))
177
+
178
+ prompt = prompt.strip('"')
179
+
180
+ generator = torch.Generator(device).manual_seed(seed) if seed else None
181
+
182
+ video_pt = pipe(prompt=prompt,image=image,num_videos_per_prompt=1,num_inference_steps=num_inference_steps,num_frames=49,use_dynamic_cfg=False,guidance_scale=guidance_scale,generator=generator,id_vit_hidden=id_vit_hidden,id_cond=id_cond,kps_cond=kps_cond,output_type="pt",).frames
183
+
184
+ ##free_memory()
185
+ return video_pt, seed
186
+
187
+ latents, seed = infer(prompt,image_input,num_inference_steps=4,guidance_scale=7.0,seed=seed_value,progress=progress,)
188
  if scale_status:
189
  latents = upscale_batch_and_concatenate(upscale_model, latents, device)
190
  if rife_status: