|
import os |
|
import glob |
|
import requests |
|
import json |
|
import cv2 |
|
import numpy as np |
|
import re |
|
import sys |
|
import torch |
|
from PIL import Image |
|
from pprint import pprint |
|
import base64 |
|
from io import BytesIO |
|
import torchvision.transforms.functional as F |
|
from torchvision.io import read_video, read_image, ImageReadMode |
|
from torchvision.models.optical_flow import Raft_Large_Weights |
|
from torchvision.models.optical_flow import raft_large |
|
from torchvision.io import read_video, read_image, ImageReadMode |
|
from torchvision.utils import flow_to_image |
|
import cv2 |
|
from torchvision.io import write_jpeg |
|
import pickle |
|
|
|
import argparse |
|
|
|
|
|
def get_args():
    """Parse command-line options: prompts, filesystem paths and resolution."""
    parser = argparse.ArgumentParser()

    # Positional text prompt plus an optional negative prompt.
    parser.add_argument('prompt')
    parser.add_argument('--negative-prompt', dest='negative_prompt', default="")

    # Filesystem locations: seed image, source frames, render destination.
    for flag, dest, default in (
        ('--init-image', 'init_image', "./init.png"),
        ('--input-dir', 'input_dir', "./Input_Images"),
        ('--output-dir', 'output_dir', "./output"),
    ):
        parser.add_argument(flag, dest=dest, default=default)

    # Output resolution passed through to the img2img request.
    for dim_flag in ('--width', '--height'):
        parser.add_argument(dim_flag, default=512, type=int)

    return parser.parse_args()
|
|
|
|
|
# Parse CLI arguments once at module load; `args` is read by the functions below.
args = get_args()
|
|
|
|
|
# Prefer GPU when available; RAFT inference on CPU works but is slow.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-trained RAFT-Large optical-flow model, switched to eval mode
# (disables training-only behavior such as dropout).
model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
model = model.eval()

# Ensure the output directory exists before any frames are written.
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
|
def get_image_paths(folder):
    """Return the sorted paths of image files (jpg/jpeg/png/bmp) in *folder*."""
    patterns = ("*.jpg", "*.jpeg", "*.png", "*.bmp")
    matches = [
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(folder, pattern))
    ]
    return sorted(matches)
|
|
|
|
|
# Sorted list of source frames to stylize, discovered in the input directory.
y_paths = get_image_paths(args.input_dir)
|
|
|
|
|
def get_controlnet_models():
    """Query the SD web-UI ControlNet extension for the models this script needs.

    Returns:
        (temporalnet_model, hed_model, openpose_model) — the full model names
        (including the "[hash]" suffix) as reported by the API.

    Raises:
        Exception: if the web API is unreachable.
        AssertionError: if any of the three required models is not installed.
    """
    url = "http://localhost:7860/controlnet/model_list"

    temporalnet_model = None
    # Raw strings: "\[" is an invalid escape in a plain string literal and
    # triggers a SyntaxWarning on modern Python; r"..." keeps the regex intact.
    temporalnet_re = re.compile(r"^temporalnetversion2 \[.{8}\]")

    hed_model = None
    hed_re = re.compile(r"^control_.*hed.* \[.{8}\]")

    openpose_model = None
    openpose_re = re.compile(r"^control_.*openpose.* \[.{8}\]")

    response = requests.get(url)
    if response.status_code == 200:
        models = response.json()
    else:
        raise Exception("Unable to list models from the SD Web API! "
                        "Is it running and is the controlnet extension installed?")

    # First match wins for each slot; avoid reusing the name `model`, which is
    # also the module-level RAFT model.
    for model_name in models['model_list']:
        if temporalnet_model is None and temporalnet_re.match(model_name):
            temporalnet_model = model_name
        elif hed_model is None and hed_re.match(model_name):
            hed_model = model_name
        elif openpose_model is None and openpose_re.match(model_name):
            openpose_model = model_name

    assert temporalnet_model is not None, "Unable to find the temporalnet2 model! Ensure it's copied into the stable-diffusion-webui/extensions/models directory!"
    assert hed_model is not None, "Unable to find the hed_model model! Ensure it's copied into the stable-diffusion-webui/extensions/models directory!"
    assert openpose_model is not None, "Unable to find the openpose model! Ensure it's copied into the stable-diffusion-webui/extensions/models directory!"

    return temporalnet_model, hed_model, openpose_model
|
|
|
|
|
# Resolve the three ControlNet model names once at startup; fails fast if any is missing.
TEMPORALNET_MODEL, HED_MODEL, OPENPOSE_MODEL = get_controlnet_models()
|
|
|
|
|
def send_request(last_image_path, optical_flow_path, current_image_path):
    """POST one img2img request for the current frame to the SD web-UI API.

    Conditions the render on three ControlNet units: HED and OpenPose driven
    by the current source frame, and TemporalNet2 driven by a 6-channel image
    (previous output RGB stacked with the optical-flow RGB visualization).

    Args:
        last_image_path: previous styled output frame (or the init image).
        optical_flow_path: flow visualization produced by infer().
        current_image_path: current source frame to stylize.

    Returns:
        Raw JSON response bytes on HTTP 200, otherwise None (error is printed).

    Raises:
        FileNotFoundError: if either conditioning image cannot be read.
    """
    url = "http://localhost:7860/sdapi/v1/img2img"

    # Previous output frame (as RGB) — first 3 channels of the TemporalNet2 input.
    last_image = cv2.imread(last_image_path)
    if last_image is None:
        raise FileNotFoundError(f"Could not read image: {last_image_path}")
    last_image = cv2.cvtColor(last_image, cv2.COLOR_BGR2RGB)

    # Optical-flow visualization (as RGB) — last 3 channels of the TemporalNet2 input.
    flow_image = cv2.imread(optical_flow_path)
    if flow_image is None:
        raise FileNotFoundError(f"Could not read image: {optical_flow_path}")
    flow_image = cv2.cvtColor(flow_image, cv2.COLOR_BGR2RGB)

    # Current source frame, base64-encoded for the JSON payload.
    with open(current_image_path, "rb") as b:
        current_image = base64.b64encode(b.read()).decode("utf-8")

    # TemporalNet2 consumes a 6-channel array: previous frame RGB + flow RGB.
    six_channel_image = np.dstack((last_image, flow_image))

    # NOTE(review): the payload is a pickled ndarray, which the temporalnet2
    # preprocessor on the server side presumably unpickles — pickle is unsafe
    # for untrusted peers, acceptable here only because the API is localhost.
    serialized_image = pickle.dumps(six_channel_image)
    encoded_image = base64.b64encode(serialized_image).decode('utf-8')

    data = {
        "init_images": [current_image],
        "inpainting_fill": 0,
        "inpaint_full_res": True,
        "inpaint_full_res_padding": 1,
        "inpainting_mask_invert": 1,
        "resize_mode": 0,
        "denoising_strength": 0.4,
        "prompt": args.prompt,
        "negative_prompt": args.negative_prompt,
        "alwayson_scripts": {
            "ControlNet": {
                "args": [
                    # Unit 1: HED edges from the current source frame.
                    {
                        "input_image": current_image,
                        "module": "hed",
                        "model": HED_MODEL,
                        "weight": 0.7,
                        "guidance": 1,
                        "pixel_perfect": True,
                        "resize_mode": 0,
                    },
                    # Unit 2: TemporalNet2 fed the pickled 6-channel image.
                    {
                        "input_image": encoded_image,
                        "model": TEMPORALNET_MODEL,
                        "module": "none",
                        "weight": 0.6,
                        "guidance": 1,
                        "threshold_a": 64,
                        "threshold_b": 64,
                        "resize_mode": 0,
                    },
                    # Unit 3: full-body pose from the current source frame.
                    {
                        "input_image": current_image,
                        "model": OPENPOSE_MODEL,
                        "module": "openpose_full",
                        "weight": 0.7,
                        "guidance": 1,
                        "pixel_perfect": True,
                        "resize_mode": 0,
                    }
                ]
            }
        },
        # Fixed seed keeps renders consistent frame-to-frame.
        "seed": 4123457655,
        "subseed": -1,
        "subseed_strength": -1,
        "sampler_index": "Euler a",
        "batch_size": 1,
        "n_iter": 1,
        "steps": 20,
        "cfg_scale": 6,
        "width": args.width,
        "height": args.height,
        "restore_faces": True,
        "include_init_images": True,
        "override_settings": {},
        "override_settings_restore_afterwards": True
    }
    response = requests.post(url, json=data)
    if response.status_code == 200:
        return response.content
    else:
        # Best effort: surface the API's error body, then signal failure to caller.
        try:
            error_data = response.json()
            print("Error:")
            print(str(error_data))
        except json.JSONDecodeError:
            print(f"Error: Unable to parse JSON error data.")
        return None
|
|
|
|
|
|
|
def infer(frameA, frameB):
    """Compute RAFT optical flow from frameA to frameB and save it as an image.

    The flow is rendered with flow_to_image, resized to the requested output
    resolution, written to the output directory, and the saved path returned.

    Relies on the module-level `model`, `device`, `args`, and the global loop
    index `i` (used to name the output file) — call only from the main loop.

    Args:
        frameA: path of the earlier frame.
        frameB: path of the later frame.

    Returns:
        Path of the written flow visualization image.
    """
    # Imported here so the fix below needs no change to the file header.
    from torchvision.io import write_png

    input_frame_1 = read_image(str(frameA), ImageReadMode.RGB)
    input_frame_2 = read_image(str(frameB), ImageReadMode.RGB)

    # Add a batch dimension: RAFT expects (N, C, H, W).
    img1_batch = input_frame_1.unsqueeze(0)
    img2_batch = input_frame_2.unsqueeze(0)

    weights = Raft_Large_Weights.DEFAULT
    transforms = weights.transforms()

    # Resize to the resolution the flow is estimated at, then apply the
    # weights' canonical preprocessing transforms.
    img1_batch = F.resize(img1_batch, size=[512, 512])
    img2_batch = F.resize(img2_batch, size=[512, 512])
    img1_batch, img2_batch = transforms(img1_batch, img2_batch)

    # Inference only — no_grad avoids building an autograd graph and saves memory.
    with torch.no_grad():
        list_of_flows = model(img1_batch.to(device), img2_batch.to(device))

    # RAFT returns one flow per refinement iteration; take the final one,
    # first (only) batch element.
    predicted_flow = list_of_flows[-1][0]
    optical_flow_path = os.path.join(args.output_dir, f"flow_{i}.png")

    flow_img = flow_to_image(predicted_flow).to("cpu")
    flow_img = F.resize(flow_img, size=[args.height, args.width])

    # BUG FIX: the original used write_jpeg, putting JPEG bytes in a .png file;
    # write real (lossless) PNG so the extension matches the content.
    write_png(flow_img, optical_flow_path)

    return optical_flow_path
|
|
|
# Accumulators for results — NOTE(review): never appended to below; apparently
# vestigial, kept for backward compatibility.
output_images = []
output_paths = []

# Seed the feedback loop: before the first generated frame exists, the
# "previous output" is the user-supplied init image.
result = args.init_image
output_image_path = os.path.join(args.output_dir, f"output_image_0.png")

last_image_path = args.init_image
|
# Frame-by-frame generation loop: each source frame is stylized conditioned on
# the previous styled output plus the optical flow between source frames.
for i in range(1, len(y_paths)):

    # Optical flow from the previous source frame to the current one.
    optical_flow = infer(y_paths[i - 1], y_paths[i])

    result = send_request(last_image_path, optical_flow, y_paths[i])
    # send_request returns None on an API error; fail loudly instead of
    # letting json.loads(None) raise a confusing TypeError.
    if result is None:
        raise RuntimeError(f"img2img request failed for frame {i}")
    data = json.loads(result)

    for j, encoded_image in enumerate(data["images"]):
        if j == 0:
            # Image 0 is the main render; it seeds the next iteration.
            output_image_path = os.path.join(args.output_dir, f"output_image_{i}.png")
            last_image_path = output_image_path
        else:
            # Remaining images are ControlNet detection-map previews.
            output_image_path = os.path.join(args.output_dir, f"controlnet_image_{j}_{i}.png")

        # BUG FIX: the write previously sat outside this inner loop, so only
        # the LAST returned image was saved and output_image_{i}.png — which
        # the next iteration reads via last_image_path — was never created
        # when the API returned multiple images. Write each image as seen.
        with open(output_image_path, "wb") as f:
            f.write(base64.b64decode(encoded_image))
    print(f"Written data for frame {i}:")
|
|