# Page header residue: "skf15963's picture" / "Duplicate from fclong/summary" / fb238e8
import os
import sys
# sys.path.insert(0, f'{PROJECT_DIR}/guided-diffusion')  # Prepended to the path so the installed library's files are no longer read.
import subprocess
import io
import torch.nn as nn
from torch.nn import functional as F
import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import math
import requests
import cv2
from resize_right import resize
from guided_diffusion.guided_diffusion.script_util import model_and_diffusion_defaults
from types import SimpleNamespace
from PIL import Image
import argparse
from guided_diffusion.guided_diffusion.unet import HFUNetModel
from tqdm.notebook import tqdm
from datetime import datetime
from guided_diffusion.guided_diffusion.script_util import create_model_and_diffusion
import clip
from transformers import BertForSequenceClassification, BertTokenizer
import gc
import random
# ======================== GLOBAL SETTING ========================
# Module-level configuration constants. Many of them are collected into the
# ``args`` namespace further down in this file; the ``# @param`` markers are
# kept from the original form-style annotations.
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
useCPU = False  # @param {type:"boolean"}
skip_augs = False  # @param{type: 'boolean'}
perlin_init = False  # @param{type: 'boolean'}
use_secondary_model = False
diffusion_model = "custom"
# Dimensions must be multiples of 64.
side_x = 512
side_y = 512
diffusion_sampling_mode = 'ddim'  # @param ['plms','ddim']
use_checkpoint = True  # @param {type: 'boolean'}
# CLIP model selection flags.
ViTB32 = False  # @param{type:"boolean"}
ViTB16 = False  # @param{type:"boolean"}
ViTL14 = True  # @param{type:"boolean"}
ViTL14_336px = False  # @param{type:"boolean"}
RN101 = False  # @param{type:"boolean"}
RN50 = False  # @param{type:"boolean"}
RN50x4 = False  # @param{type:"boolean"}
RN50x16 = False  # @param{type:"boolean"}
RN50x64 = False  # @param{type:"boolean"}
# @markdown #####**OpenCLIP settings:**
ViTB32_laion2b_e16 = False  # @param{type:"boolean"}
ViTB32_laion400m_e31 = False  # @param{type:"boolean"}
ViTB32_laion400m_32 = False  # @param{type:"boolean"}
ViTB32quickgelu_laion400m_e31 = False  # @param{type:"boolean"}
ViTB32quickgelu_laion400m_e32 = False  # @param{type:"boolean"}
ViTB16_laion400m_e31 = False  # @param{type:"boolean"}
ViTB16_laion400m_e32 = False  # @param{type:"boolean"}
RN50_yffcc15m = False  # @param{type:"boolean"}
RN50_cc12m = False  # @param{type:"boolean"}
RN50_quickgelu_yfcc15m = False  # @param{type:"boolean"}
RN50_quickgelu_cc12m = False  # @param{type:"boolean"}
RN101_yfcc15m = False  # @param{type:"boolean"}
RN101_quickgelu_yfcc15m = False  # @param{type:"boolean"}
# @markdown ####**Basic Settings:**
# NOTE: ``steps`` can be changed here, but the model then has to be
# re-initialised; the interface was not reworked to expose it.
steps = 100  # @param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}
tv_scale = 0  # @param{type: 'number'}
range_scale = 150  # @param{type: 'number'}
sat_scale = 0  # @param{type: 'number'}
# NOTE: the image is augmented here and the CLIP gradient is accumulated over
# n passes; the accumulated gradient is used as guidance.
cutn_batches = 1  # @param{type: 'number'}
# NOTE(review): skip_augs is assigned a second time here with the same value.
skip_augs = False  # @param{type: 'boolean'}
# @markdown ####**Saving:**
intermediate_saves = 0  # @param{type: 'raw'}
intermediates_in_subfolder = True  # @param{type: 'boolean'}
# perlin_init = False # @param{type: 'boolean'}
perlin_mode = 'mixed'  # @param ['mixed', 'color', 'gray']
set_seed = 'random_seed'  # @param{type: 'string'}
eta = 0.8  # @param{type: 'number'}
clamp_grad = True  # @param{type: 'boolean'}
clamp_max = 0.05  # @param{type: 'number'}
# EXTRA ADVANCED SETTINGS:
randomize_class = True
clip_denoised = False
fuzzy_prompt = False
rand_mag = 0.05
# @markdown ---
# The four schedules below are Python expressions stored as strings; they are
# evaluated into lists later in this file.
cut_overview = "[12]*400+[4]*600"  # @param {type: 'string'}
cut_innercut = "[4]*400+[12]*600"  # @param {type: 'string'}
cut_ic_pow = "[1]*1000"  # @param {type: 'string'}
cut_icgray_p = "[0.2]*400+[0]*600"  # @param {type: 'string'}
# @markdown ####**Transformation Settings:**
use_vertical_symmetry = False  # @param {type:"boolean"}
use_horizontal_symmetry = False  # @param {type:"boolean"}
transformation_percent = [0.09]  # @param
display_rate = 3  # @param{type: 'number'}
n_batches = 1  # @param{type: 'number'}
# @markdown If you're having issues with model downloads, check this to compare SHA's:
check_model_SHA = False  # @param{type:"boolean"}
interp_spline = 'Linear'  # Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:"string"}
resume_run = False
batch_size = 1
def createPath(filepath):
    """Make sure the directory ``filepath`` exists.

    Missing parent directories are created as well, and an already existing
    directory is not treated as an error.
    """
    os.makedirs(name=filepath, exist_ok=True)
def wget(url, outputdir):
    """Run the external ``wget`` program to download ``url`` into ``outputdir``.

    The process's standard output is captured, decoded as UTF-8 and printed.
    Nothing is returned.
    """
    command = ['wget', url, '-P', f'{outputdir}']
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    res = completed.stdout.decode('utf-8')
    print(res)
def alpha_sigma_to_t(alpha, sigma):
    """Return ``atan2(sigma, alpha) * 2 / pi`` for the given tensors."""
    angle = torch.atan2(sigma, alpha)
    return angle * 2 / math.pi
def interp(t):
    """Evaluate the cubic polynomial ``3 * t**2 - 2 * t**3`` at ``t``."""
    squared = t ** 2
    cubed = t ** 3
    return 3 * squared - 2 * cubed
def perlin(width, height, scale=10, device=None):
    """Generate a 2-D noise tensor of shape ``(width * scale, height * scale)``.

    Random gradients are drawn for a ``(width + 1) x (height + 1)`` lattice and
    the four corner contributions of every cell are blended with weights
    obtained from ``interp``.
    """
    gradients = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
    gx, gy = gradients
    # Sample positions inside one cell, excluding the end point 1.
    steps_1d = torch.linspace(0, 1, scale + 1)
    xs = steps_1d[:-1, None].to(device)
    ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)
    wx = 1 - interp(xs)
    wy = 1 - interp(ys)
    corner_terms = (
        wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys),
        (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys),
        wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys)),
        (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys)),
    )
    dots = 0
    for term in corner_terms:
        dots = dots + term
    # (width, height, scale, scale) -> (width, scale, height, scale) -> 2-D
    return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)
def perlin_ms(octaves, width, height, grayscale, device=None):
    """Sum several ``perlin`` layers per channel and concatenate the channels.

    Each channel starts at 0.5 and receives one ``perlin`` layer per entry of
    ``octaves``, multiplied by that entry. One channel is produced when
    ``grayscale`` is truthy, otherwise three. The per-channel results are
    joined with ``torch.cat``.
    """
    channel_count = 1 if grayscale else 3
    out_array = [0.5] * channel_count
    # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]
    for channel in range(channel_count):
        scale = 2 ** len(octaves)
        oct_width, oct_height = width, height
        for amplitude in octaves:
            layer = perlin(oct_width, oct_height, scale, device)
            out_array[channel] += layer * amplitude
            # Halve the cell resolution and double the lattice size so that
            # every layer has the same output shape.
            scale //= 2
            oct_width *= 2
            oct_height *= 2
    return torch.cat(out_array)
def fetch(url_or_path):
    """Return a binary file-like object for a URL or a local path.

    Values whose string form starts with ``http://`` or ``https://`` are
    downloaded with ``requests`` (an HTTP error status raises) and returned
    as an in-memory ``io.BytesIO`` positioned at the start. Anything else is
    opened from disk in ``'rb'`` mode.
    """
    location = str(url_or_path)
    if not location.startswith(('http://', 'https://')):
        return open(url_or_path, 'rb')
    response = requests.get(url_or_path)
    response.raise_for_status()
    return io.BytesIO(response.content)
def read_image_workaround(path):
    """Read the image at ``path`` with OpenCV and return it in RGB order.

    OpenCV loads images as BGR while Pillow saves them as RGB, so the channels
    are converted here to avoid colour inversions.
    """
    bgr_pixels = cv2.imread(path)
    rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    return rgb_pixels
def parse_prompt(prompt):
    """Split a prompt of the form ``"text:weight"`` into ``(text, weight)``.

    Args:
        prompt: Prompt string. For prompts starting with ``http://`` or
            ``https://`` the colon of the scheme is kept as part of the text.

    Returns:
        A ``(text, weight)`` tuple with ``weight`` as a float. The weight
        defaults to ``1.0`` when no ``:weight`` suffix is present. If the part
        after the last colon is not a number (for example ``"time: noon"`` or
        a URL with a port and no weight) the whole prompt is returned
        unchanged with weight ``1.0`` instead of raising ``ValueError``.
    """
    if prompt.startswith(('http://', 'https://')):
        # Split off at most two fields from the right, then glue the scheme
        # back onto the rest of the URL.
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    if len(vals) < 2:
        return vals[0], 1.0
    try:
        return vals[0], float(vals[1])
    except ValueError:
        # The trailing colon belonged to the text, not to a weight.
        return prompt, 1.0
def sinc(x):
    """Return ``sin(pi * x) / (pi * x)`` element-wise, with value 1 where ``x == 0``."""
    scaled = math.pi * x
    ratio = torch.sin(scaled) / scaled
    # torch.where picks the constant 1 at x == 0, where the ratio is undefined.
    return torch.where(x != 0, ratio, x.new_ones([]))
def lanczos(x, a):
    """Evaluate ``sinc(x) * sinc(x / a)`` inside ``(-a, a)`` and normalise it.

    Values outside the open interval are set to zero; the result is divided
    by its sum.
    """
    inside_window = torch.logical_and(-a < x, x < a)
    window = sinc(x) * sinc(x / a)
    out = torch.where(inside_window, window, x.new_zeros([]))
    return out / out.sum()
def ramp(ratio, width):
    """Build a symmetric 1-D tensor of positions spaced ``ratio`` apart.

    ``ceil(width / ratio + 1)`` non-negative positions ``0, ratio, 2*ratio, ...``
    are generated, mirrored to the negative side, and the two outermost
    entries are dropped.
    """
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    position = 0
    for index in range(n):
        out[index] = position
        position += ratio
    negative_side = -out[1:].flip([0])
    return torch.cat([negative_side, out])[1:-1]
def resample(input, size, align_corners=True):
    """Resize a 4-D tensor to ``size`` with bicubic interpolation.

    Along every axis that is being shrunk, the data is first filtered with a
    ``lanczos`` kernel (reflect-padded so the shape is preserved). The final
    resize uses ``F.interpolate`` in ``'bicubic'`` mode.

    Args:
        input: Tensor whose shape unpacks as ``(n, c, h, w)``.
        size: Target ``(height, width)``.
        align_corners: Forwarded to ``F.interpolate``.
    """
    batch, channels, src_h, src_w = input.shape
    dst_h, dst_w = size
    # Treat every channel as its own single-channel image for the 1-D filters.
    planes = input.reshape([batch * channels, 1, src_h, src_w])
    if dst_h < src_h:
        kernel_h = lanczos(ramp(dst_h / src_h, 2), 2).to(planes.device, planes.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        planes = F.pad(planes, (0, 0, pad_h, pad_h), 'reflect')
        planes = F.conv2d(planes, kernel_h[None, None, :, None])
    if dst_w < src_w:
        kernel_w = lanczos(ramp(dst_w / src_w, 2), 2).to(planes.device, planes.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        planes = F.pad(planes, (pad_w, pad_w, 0, 0), 'reflect')
        planes = F.conv2d(planes, kernel_w[None, None, None, :])
    restored = planes.reshape([batch, channels, src_h, src_w])
    return F.interpolate(restored, size, mode='bicubic', align_corners=align_corners)
class MakeCutouts(nn.Module):
    """Produce ``cutn`` square crops of an image batch, resized to ``cut_size``.

    Attributes set in ``__init__``:
        cut_size: Side length of every returned crop.
        cutn: Number of crops generated per call.
        skip_augs: When truthy, the random augmentations are not applied.
        augs: Augmentation pipeline (flip, noise, affine, perspective, grayscale).
    """

    def __init__(self, cut_size, cutn, skip_augs=False):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.skip_augs = skip_augs

        def add_noise(x):
            # Small Gaussian perturbation interleaved between the geometric steps.
            return x + torch.randn_like(x) * 0.01

        self.augs = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.Lambda(add_noise),
            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
            T.Lambda(add_noise),
            T.RandomPerspective(distortion_scale=0.4, p=0.7),
            T.Lambda(add_noise),
            T.RandomGrayscale(p=0.15),
            T.Lambda(add_noise),
            # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        ])

    def forward(self, input):
        """Return the crops concatenated along dimension 0."""
        # Zero-pad every side by a quarter of the input height.
        padded = T.Pad(input.shape[2] // 4, fill=0)(input)
        sideY, sideX = padded.shape[2:4]
        max_size = min(sideX, sideY)
        target = (self.cut_size, self.cut_size)
        # Indices above this threshold use the whole padded image, not a crop.
        whole_image_after = self.cutn - self.cutn // 4
        pieces = []
        for index in range(self.cutn):
            if index > whole_image_after:
                piece = padded.clone()
            else:
                lower = float(self.cut_size / max_size)
                fraction = torch.zeros(1,).normal_(mean=.8, std=.3).clip(lower, 1.)
                size = int(max_size * fraction)
                offsetx = torch.randint(0, abs(sideX - size + 1), ())
                offsety = torch.randint(0, abs(sideY - size + 1), ())
                piece = padded[:, :, offsety:offsety + size, offsetx:offsetx + size]
            if not self.skip_augs:
                piece = self.augs(piece)
            pieces.append(resample(piece, target))
            del piece
        return torch.cat(pieces, dim=0)
class MakeCutoutsDango(nn.Module):
    """Build "overview" and "inner crop" cutouts of an image for CLIP guidance.

    Args:
        cut_size: Side length of every returned cutout.
        args: Run settings object. Its ``skip_augs`` attribute, when present,
            decides whether the augmentations are skipped; otherwise the
            module-level ``skip_augs`` flag is used.
        Overview: Number of whole-image cutouts. Values 1-4 yield, in order,
            the resized image, its grayscale version, its horizontal flip and
            the grayscale flip; larger values repeat the resized image.
        InnerCrop: Number of random square crops.
        IC_Size_Pow: Exponent applied to the uniform random number that picks
            the crop size.
        IC_Grey_P: Crops with index ``<= int(IC_Grey_P * InnerCrop)`` are
            converted to grayscale.
    """

    def __init__(self, cut_size, args,
                 Overview=4,
                 InnerCrop=0, IC_Size_Pow=0.5, IC_Grey_P=0.2,
                 ):
        super().__init__()
        self.padargs = {}
        self.cutout_debug = False
        self.cut_size = cut_size
        self.Overview = Overview
        self.InnerCrop = InnerCrop
        self.IC_Size_Pow = IC_Size_Pow
        self.IC_Grey_P = IC_Grey_P
        # Previously ``args`` was ignored and ``forward`` read the module
        # global directly; honour the per-run setting when it is provided.
        self.skip_augs = args.skip_augs if hasattr(args, 'skip_augs') else skip_augs
        self.augs = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation=T.InterpolationMode.BILINEAR),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomGrayscale(p=0.1),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        ])

    def forward(self, input):
        """Return all cutouts concatenated along dimension 0."""
        cutouts = []
        gray = T.Grayscale(3)
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        output_shape = [1, 3, self.cut_size, self.cut_size]
        # Pads the shorter side so the overview image is (roughly) square:
        # the width is padded by the height surplus and vice versa.
        pad_input = F.pad(input, ((sideY - max_size) // 2, (sideY - max_size) // 2, (sideX - max_size) // 2, (sideX - max_size) // 2), **self.padargs)
        cutout = resize(pad_input, out_shape=output_shape)
        if self.Overview > 0:
            if self.Overview <= 4:
                if self.Overview >= 1:
                    cutouts.append(cutout)
                if self.Overview >= 2:
                    cutouts.append(gray(cutout))
                if self.Overview >= 3:
                    cutouts.append(TF.hflip(cutout))
                if self.Overview == 4:
                    cutouts.append(gray(TF.hflip(cutout)))
            else:
                # ``cutout`` already holds the resized overview; reuse it
                # instead of resizing the same input a second time.
                for _ in range(self.Overview):
                    cutouts.append(cutout)
            if self.cutout_debug:
                TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save("cutout_overview0.jpg", quality=99)
        if self.InnerCrop > 0:
            for i in range(self.InnerCrop):
                size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)
                offsetx = torch.randint(0, sideX - size + 1, ())
                offsety = torch.randint(0, sideY - size + 1, ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
                if i <= int(self.IC_Grey_P * self.InnerCrop):
                    cutout = gray(cutout)
                cutout = resize(cutout, out_shape=output_shape)
                cutouts.append(cutout)
            if self.cutout_debug:
                TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save("cutout_InnerCrop.jpg", quality=99)
        cutouts = torch.cat(cutouts)
        # Only the literal ``True`` disables the augmentations (as before).
        if self.skip_augs is not True:
            cutouts = self.augs(cutouts)
        return cutouts
def spherical_dist_loss(x, y):
    """Return ``2 * arcsin(||x̂ - ŷ|| / 2) ** 2`` along the last dimension.

    ``x̂`` and ``ŷ`` are ``x`` and ``y`` normalised along their last dimension.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    chord = (x_unit - y_unit).norm(dim=-1)
    half_angle = chord.div(2).arcsin()
    return half_angle.pow(2).mul(2)
def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    The input is replicate-padded by one pixel on the right and bottom, then
    the squared horizontal and vertical neighbour differences are averaged
    over dimensions 1, 2 and 3.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    x_diff = padded[..., :-1, 1:] - anchor
    y_diff = padded[..., 1:, :-1] - anchor
    squared_variation = x_diff**2 + y_diff**2
    return squared_variation.mean([1, 2, 3])
def range_loss(input):
    """Mean squared distance of ``input`` from the interval ``[-1, 1]``.

    The mean is taken over dimensions 1, 2 and 3.
    """
    overshoot = input - input.clamp(-1, 1)
    return overshoot.pow(2).mean([1, 2, 3])
def symmetry_transformation_fn(x):
    """Force the image tensor ``x`` to be mirror-symmetric.

    Controlled by the module-level flags ``use_horizontal_symmetry`` and
    ``use_vertical_symmetry``. When a flag is set, the first half of ``x``
    along that axis (``size // 2`` entries) is kept and followed by its
    mirror image, so an odd-sized axis shrinks by one.

    A local ``use_horizontal_symmetry = False`` used to shadow the module
    setting, which made the horizontal branch unreachable; the setting is now
    honoured.

    Args:
        x: Tensor whose size unpacks as ``(n, c, h, w)``.

    Returns:
        The (possibly) symmetrised tensor.
    """
    if use_horizontal_symmetry:
        width = x.size()[3]
        left_half = x[:, :, :, :width // 2]
        x = torch.cat((left_half, torch.flip(left_half, [-1])), -1)
        print("horizontal symmetry applied")
    if use_vertical_symmetry:
        height = x.size()[2]
        top_half = x[:, :, :height // 2, :]
        x = torch.cat((top_half, torch.flip(top_half, [-2])), -2)
        print("vertical symmetry applied")
    return x
# def split_prompts(prompts):
# prompt_series = pd.Series([np.nan for a in range(max_frames)])
# for i, prompt in prompts.items():
# prompt_series[i] = prompt
# # prompt_series = prompt_series.astype(str)
# prompt_series = prompt_series.ffill().bfill()
# return prompt_series
# ---- Other miscellaneous settings ----
# Directory settings: output images and model files live next to this script.
outDirPath = f'{PROJECT_DIR}/images_out'
createPath(outDirPath)
model_path = f'{PROJECT_DIR}/models'
createPath(model_path)
# GPU setup
DEVICE = torch.device('cuda:0' if (torch.cuda.is_available() and not useCPU) else 'cpu')
print('Using device:', DEVICE)
device = DEVICE  # At least one of the modules expects this name..
# Query the CUDA capability only when a CUDA device was actually selected.
# Checking ``not useCPU`` alone fails on machines without CUDA, where DEVICE
# has already fallen back to the CPU.
if DEVICE.type == 'cuda':
    if torch.cuda.get_device_capability(DEVICE) == (8, 0):  # A100 fix thanks to Emad
        print('Disabling CUDNN for A100 gpu', file=sys.stderr)
        torch.backends.cudnn.enabled = False
model_config = model_and_diffusion_defaults()
model_config.update({
    'attention_resolutions': '32, 16, 8',
    'class_cond': False,
    'diffusion_steps': 1000,  # No need to edit this, it is taken care of later.
    'rescale_timesteps': True,
    'timestep_respacing': 250,  # No need to edit this, it is taken care of later.
    'image_size': 512,
    'learn_sigma': True,
    'noise_schedule': 'linear',
    'num_channels': 256,
    'num_head_channels': 64,
    'num_res_blocks': 2,
    'resblock_updown': True,
    'use_checkpoint': use_checkpoint,
    # Half precision only when running on CUDA (same value as ``not useCPU``
    # whenever CUDA is available).
    'use_fp16': DEVICE.type == 'cuda',
    'use_scale_shift_norm': True,
})
model_default = model_config['image_size']
# Per-channel mean / std applied to images before they are passed to CLIP.
normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
# Make folder for batch
steps_per_checkpoint = steps + 10
# Update Model Settings
timestep_respacing = f'ddim{steps}'
# Largest multiple of ``steps`` not exceeding 1000 (or ``steps`` itself when >= 1000).
diffusion_steps = (1000 // steps) * steps if steps < 1000 else steps
model_config.update({
    'timestep_respacing': timestep_respacing,
    'diffusion_steps': diffusion_steps,
})
start_frame = 0
print('Starting Run:')
if set_seed == 'random_seed':
    random.seed()
    seed = random.randint(0, 2**32)
    # print(f'Using seed: {seed}')
else:
    seed = int(set_seed)
# NOTE(review): the cut schedules are evaluated with ``eval``. They are
# constants defined in this file; never feed externally supplied strings here.
args = {
    # 'seed': seed,
    'display_rate': display_rate,
    'n_batches': n_batches,
    'batch_size': batch_size,
    'steps': steps,
    'diffusion_sampling_mode': diffusion_sampling_mode,
    # 'width_height': width_height,
    'tv_scale': tv_scale,
    'range_scale': range_scale,
    'sat_scale': sat_scale,
    'cutn_batches': cutn_batches,
    # 'side_x': side_x,
    # 'side_y': side_y,
    'timestep_respacing': timestep_respacing,
    'diffusion_steps': diffusion_steps,
    'cut_overview': eval(cut_overview),
    'cut_innercut': eval(cut_innercut),
    'cut_ic_pow': eval(cut_ic_pow),
    'cut_icgray_p': eval(cut_icgray_p),
    'intermediate_saves': intermediate_saves,
    'intermediates_in_subfolder': intermediates_in_subfolder,
    'steps_per_checkpoint': steps_per_checkpoint,
    'set_seed': set_seed,
    'eta': eta,
    'clamp_grad': clamp_grad,
    'clamp_max': clamp_max,
    'skip_augs': skip_augs,
    'randomize_class': randomize_class,
    'clip_denoised': clip_denoised,
    'fuzzy_prompt': fuzzy_prompt,
    'rand_mag': rand_mag,
    'use_vertical_symmetry': use_vertical_symmetry,
    'use_horizontal_symmetry': use_horizontal_symmetry,
    'transformation_percent': transformation_percent,
}
args = SimpleNamespace(**args)
# ======================== GLOBAL SETTING END ========================
class Diffuser:
    """Text-to-image generator: a diffusion UNet guided by CLIP image embeddings.

    The text side uses the Taiyi Chinese CLIP text encoder loaded from Hugging
    Face; the image side uses the OpenAI CLIP models enabled by the
    module-level ``ViT*`` flags.
    """

    def __init__(self, cutom_path='IDEA-CCNL/Taiyi-Diffusion-532M-Nature'):
        """Load all models. ``cutom_path`` is the diffusion checkpoint to load."""
        self.model_setup(cutom_path)

    def model_setup(self, custom_path):
        """Load the diffusion model, the text encoder and the CLIP image models.

        Args:
            custom_path: Name or path handed to ``HFUNetModel.from_pretrained``.
        """
        # LOADING MODEL
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
        print(f'Prepping model...model name: {custom_path}')
        # Only the diffusion process is kept; the UNet weights come from
        # ``custom_path`` below.
        __, self.diffusion = create_model_and_diffusion(**model_config)
        self.model = HFUNetModel.from_pretrained(custom_path)
        self.model.requires_grad_(False).eval().to(device)
        for name, param in self.model.named_parameters():
            if 'qkv' in name or 'norm' in name or 'proj' in name:
                param.requires_grad_()
        if model_config['use_fp16']:
            self.model.convert_to_fp16()
        print(f'Diffusion_model Loaded {diffusion_model}')
        # NOTE Directly Load The Text Encoder From Hugging Face
        print('Prepping model...model name: CLIP')
        self.taiyi_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese")
        self.taiyi_transformer = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese").eval().to(device)
        self.clip_models = []
        if ViTB32:
            self.clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device))
        if ViTB16:
            self.clip_models.append(clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device))
        if ViTL14:
            self.clip_models.append(clip.load('ViT-L/14', jit=False)[0].eval().requires_grad_(False).to(device))
        if ViTL14_336px:
            self.clip_models.append(clip.load('ViT-L/14@336px', jit=False)[0].eval().requires_grad_(False).to(device))
        print('CLIP Loaded')
        # No LPIPS model is loaded in this file. The attribute is defined so
        # ``generate`` can detect that instead of failing with AttributeError.
        # self.lpips_model = lpips.LPIPS(net='vgg').to(device)
        self.lpips_model = None

    def generate(self,
                 input_text_prompts=['夕阳西下'],
                 init_image=None,
                 skip_steps=10,
                 clip_guidance_scale=7500,
                 init_scale=2000,
                 st_dynamic_image=None,
                 seed=None,
                 side_x=512,
                 side_y=512,
                 ):
        """Run guided sampling and return the last decoded image.

        Args:
            input_text_prompts: Prompts, each optionally suffixed with
                ``:weight`` (see ``parse_prompt``). Not modified.
            init_image: Optional image object (not a path) to start from; it
                is resized to ``(side_x, side_y)``.
            skip_steps: Number of initial diffusion steps to skip.
            clip_guidance_scale: Multiplier of the CLIP loss gradient.
            init_scale: Multiplier of the perceptual loss against
                ``init_image``. Only used when an LPIPS model is available on
                the instance.
            st_dynamic_image: Optional object with an ``image(...)`` method
                that receives every saved preview.
            seed: Currently unused.
            side_x: Output width.
            side_y: Output height.

        Returns:
            The last image converted with ``TF.to_pil_image``, or ``None`` if
            no image was produced.

        Raises:
            RuntimeError: If the prompt weights sum to (almost) zero.
        """
        frame_num = 0
        loss_values = []
        frame_prompt = input_text_prompts
        print(f'Frame {frame_num} Prompt: {frame_prompt}')
        model_stats = []
        for clip_model in self.clip_models:
            model_stat = {"clip_model": None, "target_embeds": [], "make_cutouts": None, "weights": []}
            model_stat["clip_model"] = clip_model
            for prompt in frame_prompt:
                txt, weight = parse_prompt(prompt)
                # NOTE use chinese CLIP
                txt = self.taiyi_transformer(self.taiyi_tokenizer(txt, return_tensors='pt')['input_ids'].to(device)).logits
                if args.fuzzy_prompt:
                    for _ in range(25):
                        # Noise is created on the embedding's own device
                        # (was hard-coded to ``.cuda()``, which breaks CPU runs).
                        noise = torch.randn(txt.shape, device=txt.device)
                        model_stat["target_embeds"].append((txt + noise * args.rand_mag).clamp(0, 1))
                        model_stat["weights"].append(weight)
                else:
                    model_stat["target_embeds"].append(txt)
                    model_stat["weights"].append(weight)
            model_stat["target_embeds"] = torch.cat(model_stat["target_embeds"])
            model_stat["weights"] = torch.tensor(model_stat["weights"], device=device)
            if model_stat["weights"].sum().abs() < 1e-3:
                raise RuntimeError('The weights must not sum to 0.')
            model_stat["weights"] /= model_stat["weights"].sum().abs()
            model_stats.append(model_stat)

        init = None
        if init_image is not None:
            # An already-loaded image is passed in, not a path or URL.
            init = init_image
            init = init.resize((side_x, side_y), Image.LANCZOS)
            init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

        # The perceptual init loss needs an LPIPS model, which ``model_setup``
        # does not load. Skip that loss term (with a notice) instead of
        # crashing in the middle of sampling.
        lpips_model = getattr(self, 'lpips_model', None)
        use_init_loss = init is not None and bool(init_scale) and lpips_model is not None
        if init is not None and init_scale and lpips_model is None:
            print('No LPIPS model loaded; the init_scale loss term is skipped.')

        # The file name uses the text of the last prompt. Path separators are
        # replaced so the image is always written directly into outDirPath.
        file_label = parse_prompt(frame_prompt[-1])[0] if frame_prompt else ''
        file_label = file_label.replace('/', '_').replace(os.sep, '_')

        cur_t = None

        def cond_fn(x, t, y=None):
            """Return the guidance gradient for the current sample ``x``."""
            with torch.enable_grad():
                x_is_NaN = False
                x = x.detach().requires_grad_()
                n = x.shape[0]
                my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
                out = self.diffusion.p_mean_variance(self.model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
                fac = self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
                # Blend of the predicted clean image and the noisy sample.
                x_in = out['pred_xstart'] * fac + x * (1 - fac)
                x_in_grad = torch.zeros_like(x_in)
                for model_stat in model_stats:
                    for _ in range(args.cutn_batches):
                        t_int = int(t.item()) + 1  # errors on last step without +1, need to find source
                        schedule_index = 1000 - t_int
                        input_resolution = model_stat["clip_model"].visual.input_resolution
                        cuts = MakeCutoutsDango(input_resolution,
                                                Overview=args.cut_overview[schedule_index],
                                                InnerCrop=args.cut_innercut[schedule_index],
                                                IC_Size_Pow=args.cut_ic_pow[schedule_index],
                                                IC_Grey_P=args.cut_icgray_p[schedule_index],
                                                args=args,
                                                )
                        clip_in = normalize(cuts(x_in.add(1).div(2)))
                        image_embeds = model_stat["clip_model"].encode_image(clip_in).float()
                        dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat["target_embeds"].unsqueeze(0))
                        dists = dists.view([args.cut_overview[schedule_index] + args.cut_innercut[schedule_index], n, -1])
                        losses = dists.mul(model_stat["weights"]).sum(2).mean(0)
                        loss_values.append(losses.sum().item())  # log loss, probably shouldn't do per cutn_batch
                        x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / args.cutn_batches
                tv_losses = tv_loss(x_in)
                range_losses = range_loss(out['pred_xstart'])
                sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean()
                loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale
                if use_init_loss:
                    init_losses = lpips_model(x_in, init)
                    loss = loss + init_losses.sum() * init_scale
                x_in_grad += torch.autograd.grad(loss, x_in)[0]
                if not torch.isnan(x_in_grad).any():
                    grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
                else:
                    # A NaN gradient is replaced by zero guidance for this step.
                    x_is_NaN = True
                    grad = torch.zeros_like(x)
            if args.clamp_grad and not x_is_NaN:
                magnitude = grad.square().mean().sqrt()
                return grad * magnitude.clamp(max=args.clamp_max) / magnitude  # min=-0.02, min=-clamp_max,
            return grad

        if args.diffusion_sampling_mode == 'ddim':
            sample_fn = self.diffusion.ddim_sample_loop_progressive
        else:
            sample_fn = self.diffusion.plms_sample_loop_progressive

        # Stays None when no sample is ever decoded (e.g. n_batches == 0).
        image = None
        for i in range(args.n_batches):
            current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')
            batchBar = tqdm(range(args.n_batches), desc="Batches")
            batchBar.n = i
            batchBar.refresh()
            gc.collect()
            torch.cuda.empty_cache()
            cur_t = self.diffusion.num_timesteps - skip_steps - 1
            if args.diffusion_sampling_mode == 'ddim':
                samples = sample_fn(
                    self.model,
                    (batch_size, 3, side_y, side_x),
                    clip_denoised=clip_denoised,
                    model_kwargs={},
                    cond_fn=cond_fn,
                    progress=True,
                    skip_timesteps=skip_steps,
                    init_image=init,
                    randomize_class=randomize_class,
                    eta=eta,
                    transformation_fn=symmetry_transformation_fn,
                    transformation_percent=args.transformation_percent
                )
            else:
                samples = sample_fn(
                    self.model,
                    (batch_size, 3, side_y, side_x),
                    clip_denoised=clip_denoised,
                    model_kwargs={},
                    cond_fn=cond_fn,
                    progress=True,
                    skip_timesteps=skip_steps,
                    init_image=init,
                    randomize_class=randomize_class,
                    order=2,
                )
            for j, sample in enumerate(samples):
                cur_t -= 1
                intermediateStep = False
                if args.steps_per_checkpoint is not None:
                    if j % steps_per_checkpoint == 0 and j > 0:
                        intermediateStep = True
                elif j in args.intermediate_saves:
                    intermediateStep = True
                is_display_step = j % args.display_rate == 0 or cur_t == -1
                if is_display_step or intermediateStep:
                    for image in sample['pred_xstart']:
                        filename = f'{current_time}-{file_label}.png'
                        image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
                        if is_display_step:
                            image.save(f'{outDirPath}/{filename}')
                            if st_dynamic_image:
                                st_dynamic_image.image(image, use_column_width=True)
        return image
if __name__ == '__main__':
    # Command-line entry point: build a Diffuser and generate one image from
    # the given prompt.
    # (flag, value type, extra add_argument keywords) for every option.
    option_table = [
        ('--prompt', str, {'required': True}),
        ('--text_scale', int, {'default': 5000}),
        ('--model_path', str, {'default': "IDEA-CCNL/Taiyi-Diffusion-532M-Nature"}),
        ('--width', int, {'default': 512}),
        ('--height', int, {'default': 512}),
    ]
    parser = argparse.ArgumentParser(description="setting")
    for flag, value_type, extra in option_table:
        parser.add_argument(flag, type=value_type, **extra)
    user_args = parser.parse_args()
    dd = Diffuser(user_args.model_path)
    generation_options = {
        'clip_guidance_scale': user_args.text_scale,
        'side_x': user_args.width,
        'side_y': user_args.height,
    }
    dd.generate([user_args.prompt], **generation_options)