Spaces:
Paused
Paused
import os | |
import cv2 | |
import torch | |
from gfpgan import GFPGANer | |
from tqdm import tqdm | |
from basicsr.archs.rrdbnet_arch import RRDBNet | |
from realesrgan import RealESRGANer | |
def load_video_to_cv2(input_path):
    """Read every frame of a video into memory as RGB arrays.

    Args:
        input_path: Path (or any source string accepted by
            ``cv2.VideoCapture``) of the video to read.

    Returns:
        A ``(frames, fps)`` tuple. ``frames`` is a list of RGB frames in
        decode order (all held in memory at once) and ``fps`` is the frame
        rate reported by OpenCV's ``CAP_PROP_FPS``.

    Raises:
        OSError: If OpenCV cannot open ``input_path``.
    """
    video_stream = cv2.VideoCapture(input_path)
    # VideoCapture does not raise for a missing/unreadable source; without
    # this check the caller would silently get an empty frame list.
    if not video_stream.isOpened():
        video_stream.release()
        raise OSError(f"Could not open video: {input_path}")
    try:
        fps = video_stream.get(cv2.CAP_PROP_FPS)
        full_frames = []
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                break
            # OpenCV decodes to BGR; the rest of this module works in RGB.
            full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even if a conversion fails mid-stream.
        video_stream.release()
    return full_frames, fps
def save_frames_to_video(frames, output_path, fps):
    """Encode a sequence of RGB frames into a video file using the ``mp4v`` codec.

    Args:
        frames: Non-empty sequence of RGB frames; every frame must have the
            same height and width as the first one.
        output_path: Destination file path.
        fps: Frame rate written into the output container.

    Raises:
        ValueError: If ``frames`` is empty, or a frame's height/width differs
            from the first frame's.
        OSError: If OpenCV cannot open a writer for ``output_path``.
    """
    if len(frames) == 0:
        raise ValueError("No frames to write to video.")
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    # VideoWriter does not raise when it fails to open (unusable codec, path
    # or fps); writes would then produce no output, so fail loudly instead.
    if not video_writer.isOpened():
        video_writer.release()
        raise OSError(f"Could not open video writer for: {output_path}")
    try:
        for index, frame in enumerate(frames):
            # The writer is fixed to the first frame's size; OpenCV may drop
            # differently-sized frames without reporting an error.
            if frame.shape[:2] != (height, width):
                raise ValueError(
                    f"Frame {index} has size {frame.shape[1]}x{frame.shape[0]}, "
                    f"expected {width}x{height}."
                )
            # Frames are RGB; OpenCV's writer expects BGR.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        video_writer.release()
def _build_gfpgan_restorer(model_path, bg_model_path, half):
    """Create a GFPGAN restorer (upscale 2) whose background is upsampled by Real-ESRGAN x2."""
    # RRDBNet hyper-parameters must match the RealESRGAN_x2plus weights.
    realesrgan_model = RRDBNet(
        num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2
    )
    bg_upsampler = RealESRGANer(
        scale=2,
        model_path=bg_model_path,
        model=realesrgan_model,
        tile=400,  # tiled inference bounds GPU memory on large frames
        tile_pad=10,
        pre_pad=0,
        half=half,
    )
    return GFPGANer(
        model_path=model_path,
        upscale=2,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=bg_upsampler,
    )


def process_video_with_gfpgan(input_video_path, output_video_path,
                              model_path='gfpgan/weights/GFPGANv1.4.pth',
                              bg_model_path='gfpgan/weights/RealESRGAN_x2plus.pth',
                              half=None):
    """Restore faces in every frame of a video with GFPGAN and write the result.

    Frames are processed one at a time by GFPGAN (``upscale=2``) with a
    Real-ESRGAN x2 background upsampler, then re-encoded at the source fps.

    Args:
        input_video_path: Video to read.
        output_video_path: Destination path for the enhanced video.
        model_path: GFPGAN weights file.
        bg_model_path: Real-ESRGAN x2 weights used for the background.
        half: Run the background upsampler in fp16. ``None`` (default)
            enables it only when CUDA is available, because half-precision
            inference is not supported on CPU.

    Raises:
        ValueError: If no frames could be read from ``input_video_path``.
    """
    frames, fps = load_video_to_cv2(input_video_path)
    # Fail before the (expensive) model loading if there is nothing to do.
    if not frames:
        raise ValueError(f"No frames could be read from {input_video_path}.")

    if half is None:
        half = torch.cuda.is_available()
    restorer = _build_gfpgan_restorer(model_path, bg_model_path, half)

    enhanced_frames = []
    print("Enhancing frames...")
    for frame in tqdm(frames, desc='Processing Frames'):
        # Frames are held as RGB; the restorer is fed BGR, as OpenCV produces.
        bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        _, _, enhanced_bgr = restorer.enhance(
            bgr_frame, has_aligned=False, only_center_face=False, paste_back=True
        )
        enhanced_frames.append(cv2.cvtColor(enhanced_bgr, cv2.COLOR_BGR2RGB))

    save_frames_to_video(enhanced_frames, output_video_path, fps)
    print(f'Enhanced video saved at {output_video_path}')