# Provenance (from the hosting page): duplicated from shi-labs/Versatile-Diffusion by ElixirRod, revision 3ac711b.
import gradio as gr
import os
import PIL
from PIL import Image
from pathlib import Path
import numpy as np
import numpy.random as npr
from contextlib import nullcontext
import torch
import torchvision.transforms as tvtrans
from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model
# Default batch sizes used when a caller does not pass `n_samples`
# (image outputs vs. text outputs).
n_sample_image_default: int = 2
n_sample_text_default: int = 4
# Forwarded to every `gr.Examples(cache_examples=...)` call in the UI.
cache_examples: bool = True
# Hugging Face Hub repository and file holding the pretrained weights.
hfm_repo_id: str = 'shi-labs/versatile-diffusion-model'
hfm_filename: str = 'pretrained_pth/vd-four-flow-v1-0-fp16.pth'
def highlight_print(info):
    """Print `info` framed by a box of '#' characters, with a blank line before and after."""
    print('')
    border = '#' * (len(info) + 4)
    print(border)
    print('# ' + info + ' #')
    print(border)
    print('')
class color_adjust(object):
def __init__(self, ref_from, ref_to):
x0, m0, std0 = self.get_data_and_stat(ref_from)
x1, m1, std1 = self.get_data_and_stat(ref_to)
self.ref_from_stat = (m0, std0)
self.ref_to_stat = (m1, std1)
self.ref_from = self.preprocess(x0).reshape(-1, 3)
self.ref_to = x1.reshape(-1, 3)
def get_data_and_stat(self, x):
if isinstance(x, str):
x = np.array(PIL.Image.open(x))
elif isinstance(x, PIL.Image.Image):
x = np.array(x)
elif isinstance(x, torch.Tensor):
x = torch.clamp(x, min=0.0, max=1.0)
x = np.array(tvtrans.ToPILImage()(x))
elif isinstance(x, np.ndarray):
pass
else:
raise ValueError
x = x.astype(float)
m = np.reshape(x, (-1, 3)).mean(0)
s = np.reshape(x, (-1, 3)).std(0)
return x, m, s
def preprocess(self, x):
m0, s0 = self.ref_from_stat
m1, s1 = self.ref_to_stat
y = ((x-m0)/s0)*s1 + m1
return y
def __call__(self, xin, keep=0, simple=False):
xin, _, _ = self.get_data_and_stat(xin)
x = self.preprocess(xin)
if simple:
y = (x*(1-keep) + xin*keep)
y = np.clip(y, 0, 255).astype(np.uint8)
return y
h, w = x.shape[:2]
x = x.reshape(-1, 3)
y = []
for chi in range(3):
yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
y.append(yi)
y = np.stack(y, axis=1)
y = y.reshape(h, w, 3)
y = (y.astype(float)*(1-keep) + xin.astype(float)*keep)
y = np.clip(y, 0, 255).astype(np.uint8)
return y
def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
arr = np.concatenate((arr_fo, arr_to))
min_v = arr.min() - 1e-6
max_v = arr.max() + 1e-6
min_vto = arr_to.min() - 1e-6
max_vto = arr_to.max() + 1e-6
xs = np.array(
[min_v + (max_v - min_v) * i / n for i in range(n + 1)])
hist_fo, _ = np.histogram(arr_fo, xs)
hist_to, _ = np.histogram(arr_to, xs)
xs = xs[:-1]
# compute probability distribution
cum_fo = np.cumsum(hist_fo)
cum_to = np.cumsum(hist_to)
d_fo = cum_fo / cum_fo[-1]
d_to = cum_to / cum_to[-1]
# transfer
t_d = np.interp(d_fo, d_to, xs)
t_d[d_fo <= d_to[ 0]] = min_vto
t_d[d_fo >= d_to[-1]] = max_vto
arr_out = np.interp(arr_in, xs, t_d)
return arr_out
class vd_inference(object):
    """Inference wrapper around the Versatile Diffusion (VD) network.

    Builds the `vd_noema` model, loads its weights (from a local checkpoint or
    the Hugging Face Hub), optionally converts it to fp16, and exposes the
    flows used by the demo:

    * `inference` -- image/text generation conditioned on a prompt or image;
    * `application_disensemble` -- semantic/style disentanglement;
    * `application_dualguided` -- joint image + text guidance;
    * `application_i2t2i` -- image -> text-latent edit -> image.

    All sampling uses DDIM with `_DDIM_STEPS` steps and eta `_DDIM_ETA`.
    """

    _DDIM_STEPS = 50
    _DDIM_ETA = 0.0
    _IMAGE_SIZE = 512        # side length input images are resized to
    _TEXT_LATENT_DIM = 768   # length of the sampled text latent vector

    def __init__(self, pth=None, hfm_repo=None, fp16=False, device=0):
        """Build the network, load weights and create the DDIM sampler.

        Args:
            pth: optional local checkpoint path; takes precedence over `hfm_repo`.
            hfm_repo: `[repo_id, filename]` on the Hugging Face Hub, used when
                `pth` is None.
            fp16: convert the network to half precision.
            device: device passed to `net.to`.

        Raises:
            ValueError: if neither `pth` nor `hfm_repo` is given.
        """
        cfgm_name = 'vd_noema'
        cfgm = model_cfg_bank()(cfgm_name)
        net = get_model()(cfgm)
        if fp16:
            highlight_print('Running in FP16')
            net.clip.fp16 = True
            net = net.half()
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        if pth is not None:
            sd = torch.load(pth, map_location='cpu')
            print('Load pretrained weight from {}'.format(pth))
        elif hfm_repo is not None:
            from huggingface_hub import hf_hub_download
            temppath = hf_hub_download(hfm_repo[0], hfm_repo[1])
            sd = torch.load(temppath, map_location='cpu')
            print('Load pretrained weight from {}/{}'.format(*hfm_repo))
        else:
            raise ValueError('Either `pth` or `hfm_repo` must be provided.')
        # strict=False tolerates missing/unexpected keys in the checkpoint.
        net.load_state_dict(sd, strict=False)
        net.to(device)
        self.device = device
        self.model_name = cfgm_name
        self.net = net
        self.fp16 = fp16
        from lib.model_zoo.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD(net)

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    def _half(self, *tensors):
        """Cast every non-None tensor to fp16 when running in fp16 mode."""
        if not self.fp16:
            return tensors
        return tuple(None if t is None else t.half() for t in tensors)

    def _encode_vision(self, image, n_samples, scale):
        """Regularize `image` and CLIP-encode it for a batch of `n_samples`.

        Returns `(image, c, u)`: the regularized image tensor, the conditional
        embedding, and the embedding of an all-zero image used as the
        unconditional context (None when `scale == 1.0`).
        """
        image = self.regularize_image(image)
        batch = (image * 2 - 1)[None].repeat(n_samples, 1, 1, 1)
        c = self.net.clip_encode_vision(batch)
        u = None
        if scale != 1.0:
            u = self.net.clip_encode_vision(torch.zeros_like(batch))
        return image, c, u

    def _encode_text(self, prompt, n_samples, scale):
        """CLIP-encode `prompt` (and, if `scale != 1.0`, the empty prompt) `n_samples` times."""
        c = self.net.clip_encode_text(n_samples * [prompt])
        u = None
        if scale != 1.0:
            u = self.net.clip_encode_text(n_samples * [""])
        return c, u

    def _image_latent_shape(self, n_samples):
        """Latent shape for a batch of images: 4 channels at 1/8 of `_IMAGE_SIZE`."""
        side = self._IMAGE_SIZE // 8
        return [n_samples, 4, side, side]

    def _sample(self, shape, c, u, scale, xtype, ctype):
        """Run single-context DDIM sampling and return the latent."""
        z, _ = self.sampler.sample(
            steps=self._DDIM_STEPS,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=u,
            xtype=xtype, ctype=ctype,
            eta=self._DDIM_ETA,
            verbose=False,)
        return z

    def _sample_dc(self, shape, first, second, scale, mixed_ratio):
        """Run dual-context (vision first, prompt second) DDIM sampling for images."""
        z, _ = self.sampler.sample_dc(
            steps=self._DDIM_STEPS,
            shape=shape,
            first_conditioning=first,
            second_conditioning=second,
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=self._DDIM_ETA,
            verbose=False,
            mixed_ratio=mixed_ratio, )
        return z

    def _low_rank(self, x, demean, q, niter, q_remove):
        """Rank-`q` PCA reconstruction of `x`, with the top `q_remove` components zeroed.

        Works on a single matrix or a batch of matrices (leading dims).
        """
        if demean:
            # Mean over the last axis is removed before PCA and added back after.
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x
        # PCA is run in fp32 (the original code also upcast fp16 inputs here);
        # the result is cast back afterwards.
        fp16 = x_input.dtype == torch.float16
        if fp16:
            x_input = x_input.float()
        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        s[..., :q_remove] = 0
        x_lowrank = u @ torch.diag_embed(s) @ v.transpose(-2, -1)
        if fp16:
            x_lowrank = x_lowrank.half()
        if demean:
            x_lowrank = x_lowrank + x_mean
        return x_lowrank

    # ------------------------------------------------------------------ #
    # Pre/post-processing
    # ------------------------------------------------------------------ #
    def regularize_image(self, x):
        """Convert `x` to a (C, 512, 512) tensor on `self.device` (fp16 if enabled).

        Paths, PIL images and numpy arrays are converted to RGB, resized with
        bicubic resampling and turned into tensors via `ToTensor`; tensors are
        passed through unchanged before the size check.

        Raises:
            TypeError: for an unsupported input type.
            ValueError: if the resulting tensor is not 512x512.
        """
        BICUBIC = PIL.Image.Resampling.BICUBIC
        size = [self._IMAGE_SIZE, self._IMAGE_SIZE]
        if isinstance(x, str):
            # Context manager closes the file; RGB avoids 1- or 4-channel tensors.
            with Image.open(x) as img:
                x = img.convert('RGB').resize(size, resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, PIL.Image.Image):
            x = x.convert('RGB').resize(size, resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, np.ndarray):
            x = PIL.Image.fromarray(x).convert('RGB').resize(size, resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, torch.Tensor):
            pass
        else:
            raise TypeError('Unknown image type')
        if x.shape[1] != self._IMAGE_SIZE or x.shape[2] != self._IMAGE_SIZE:
            raise ValueError('Wrong image size')
        x = x.to(self.device)
        if self.fp16:
            x = x.half()
        return x

    def decode(self, z, xtype, ctype, color_adj='None', color_adj_to=None):
        """Decode latents `z` into images or text.

        For `xtype == 'image'` returns a list of PIL images, or, when colour
        calibration is requested for a vision-conditioned run, a list of uint8
        arrays produced by `color_adjust` against `color_adj_to`. For
        `xtype == 'text'` returns a list of strings in which consecutive
        duplicate words are collapsed. Returns None for any other `xtype`.
        """
        net = self.net
        if xtype == 'image':
            x = net.autokl_decode(z)
            color_adj_flag = (color_adj is not None) and (color_adj not in ('none', 'None'))
            color_adj_simple = color_adj in ('Simple', 'simple')
            color_adj_keep_ratio = 0.5
            if color_adj_flag and (ctype == 'vision'):
                x_adj = []
                for xi in x:
                    # (xi + 1) / 2 assumes decoder output in [-1, 1] -- TODO confirm.
                    color_adj_f = color_adjust(ref_from=(xi + 1) / 2, ref_to=color_adj_to)
                    x_adj.append(color_adj_f(
                        (xi + 1) / 2, keep=color_adj_keep_ratio, simple=color_adj_simple))
                return x_adj
            x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
            return [tvtrans.ToPILImage()(xi) for xi in x]
        if xtype == 'text':
            prompt_temperature = 1.0
            prompt_merge_same_adj_word = True
            x = net.optimus_decode(z, temperature=prompt_temperature)
            if prompt_merge_same_adj_word:
                merged = []
                for xi in x:
                    words = xi.split()
                    kept = [w for i, w in enumerate(words) if i == 0 or w != words[i - 1]]
                    merged.append(' '.join(kept))
                x = merged
            return x
        return None

    # ------------------------------------------------------------------ #
    # Flows and applications
    # ------------------------------------------------------------------ #
    def inference(self, xtype, cin, ctype, scale=7.5, n_samples=None, color_adj=None,):
        """Generate `xtype` ('image' or 'text') conditioned on `cin`.

        Args:
            xtype: output modality, 'image' or 'text'.
            cin: a prompt string (ctype 'prompt'/'text') or an image
                (ctype 'vision'/'image').
            ctype: conditioning modality.
            scale: classifier-free guidance scale; 1.0 skips the
                unconditional embedding.
            n_samples: batch size (defaults depend on `xtype`).
            color_adj: colour calibration mode forwarded to `decode`.

        Raises:
            ValueError: for an unknown `xtype` or `ctype`.
        """
        if xtype == 'image':
            n_samples = n_sample_image_default if n_samples is None else n_samples
        elif xtype == 'text':
            n_samples = n_sample_text_default if n_samples is None else n_samples
        else:
            raise ValueError('Unknown xtype: {}'.format(xtype))
        if ctype in ['prompt', 'text']:
            c, u = self._encode_text(cin, n_samples, scale)
        elif ctype in ['vision', 'image']:
            cin, c, u = self._encode_vision(cin, n_samples, scale)
        else:
            raise ValueError('Unknown ctype: {}'.format(ctype))
        # `u` is None when scale == 1.0; `_half` skips None instead of crashing.
        u, c = self._half(u, c)
        if xtype == 'image':
            z = self._sample(self._image_latent_shape(n_samples), c, u, scale, xtype, ctype)
            return self.decode(z, xtype, ctype, color_adj=color_adj, color_adj_to=cin)
        z = self._sample([n_samples, self._TEXT_LATENT_DIM], c, u, scale, xtype, ctype)
        return self.decode(z, xtype, ctype)

    def application_disensemble(self, cin, n_samples=None, level=0, color_adj=None,):
        """Image variation with the image's local tokens reshaped by PCA.

        The first CLIP token is kept; the remaining tokens are modified per
        `level`: -1/-2 remove the top 1/2 principal components, 1/2 keep only a
        rank-10/rank-2 approximation, 0 (or any other value) leaves them as-is.
        """
        scale = 7.5
        n_samples = n_sample_image_default if n_samples is None else n_samples
        cin, c, u = self._encode_vision(cin, n_samples, scale)
        u, c = self._half(u, c)
        if level != 0:
            c_glb, c_loc = c[:, 0:1], c[:, 1:]
            u_glb, u_loc = u[:, 0:1], u[:, 1:]
            if level == -1:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=1)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=1)
            elif level == -2:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=2)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=2)
            elif level == 1:
                c_loc = self.find_low_rank(c_loc, demean=True, q=10)
                u_loc = self.find_low_rank(u_loc, demean=True, q=10)
            elif level == 2:
                c_loc = self.find_low_rank(c_loc, demean=True, q=2)
                u_loc = self.find_low_rank(u_loc, demean=True, q=2)
            c = torch.cat([c_glb, c_loc], dim=1)
            u = torch.cat([u_glb, u_loc], dim=1)
        z = self._sample(self._image_latent_shape(n_samples), c, u, scale, 'image', 'vision')
        return self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=cin)

    def find_low_rank(self, x, demean=True, q=20, niter=10):
        """Return the rank-`q` PCA reconstruction of `x` (optionally demeaned on the last axis)."""
        return self._low_rank(x, demean=demean, q=q, niter=niter, q_remove=0)

    def remove_low_rank(self, x, demean=True, q=20, niter=10, q_remove=10):
        """Return the rank-`q` PCA reconstruction of `x` with its top `q_remove` components removed."""
        return self._low_rank(x, demean=demean, q=q, niter=niter, q_remove=q_remove)

    def application_dualguided(self, cim, ctx, n_samples=None, mixing=0.5, color_adj=None, ):
        """Generate images guided by both image `cim` and prompt `ctx`.

        `mixing` is forwarded to the sampler as `mixed_ratio = 1 - mixing`.
        """
        scale = 7.5
        n_samples = n_sample_image_default if n_samples is None else n_samples
        ctemp0, cim, uim = self._encode_vision(cim, n_samples, scale)
        ctx, utx = self._encode_text(ctx, n_samples, scale)
        uim, cim = self._half(uim, cim)
        utx, ctx = self._half(utx, ctx)
        z = self._sample_dc(
            self._image_latent_shape(n_samples),
            [uim, cim], [utx, ctx], scale, mixed_ratio=(1 - mixing))
        return self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)

    def application_i2t2i(self, cim, ctx_n, ctx_p, n_samples=None, color_adj=None,):
        """Image -> text latent -> edited text -> image.

        Text latents are sampled from the image. Their component along the
        direction of `ctx_n`'s latent is removed and `ctx_p`'s latent is added.
        The decoded texts (concatenated with `ctx_p`'s embedding) then guide
        image generation together with a low-rank version of the image's
        local tokens.
        """
        net = self.net
        scale = 7.5
        prompt_temperature = 1.0
        n_samples = n_sample_image_default if n_samples is None else n_samples
        ctemp0, cim, uim = self._encode_vision(cim, n_samples, scale)
        uim, cim = self._half(uim, cim)
        zt = self._sample([n_samples, self._TEXT_LATENT_DIM], cim, uim, scale, 'text', 'vision')

        ztn = net.optimus_encode([ctx_n])
        ztp = net.optimus_encode([ctx_p])
        # keepdim=True normalises each row on its own (identical for batch 1).
        ztn_norm = ztn / ztn.norm(dim=1, keepdim=True)
        zt_proj_mag = torch.matmul(zt, ztn_norm[0])
        zt_perp = zt - zt_proj_mag[:, None] * ztn_norm
        zt_newd = zt_perp + ztp
        ctx_new = net.optimus_decode(zt_newd, temperature=prompt_temperature)

        ctx_new = net.clip_encode_text(ctx_new)
        ctx_p = net.clip_encode_text([ctx_p])
        ctx_new = torch.cat([ctx_new, ctx_p.repeat(n_samples, 1, 1)], dim=1)
        utx_new = net.clip_encode_text(n_samples * [""])
        utx_new = torch.cat([utx_new, utx_new], dim=1)
        # Match the precision handling used for text contexts in application_dualguided.
        utx_new, ctx_new = self._half(utx_new, ctx_new)

        cim_new = self.find_low_rank(cim[:, 1:], demean=True, q=10)
        uim_new = uim[:, 1:]
        z = self._sample_dc(
            self._image_latent_shape(n_samples),
            [uim_new, cim_new], [utx_new, ctx_new], scale, mixed_ratio=0.33)
        return self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
# Build the single, module-wide model instance used by `main`.
# NOTE: this rebinds the name `vd_inference` from the class to its instance.
vd_inference = vd_inference(
    hfm_repo=[hfm_repo_id, hfm_filename],
    fp16=True,
    device='cuda',
)
def main(mode,
         image=None,
         prompt=None,
         nprompt=None,
         pprompt=None,
         color_adj=None,
         disentanglement_level=None,
         dual_guided_mixing=None,
         seed=0,):
    """Run one Versatile Diffusion flow for the Gradio UI.

    Args:
        mode: one of 'Text-to-Image', 'Image-Variation', 'Image-to-Text',
            'Text-Variation', 'Disentanglement', 'Dual-Guided', 'Latent-I2T2I'.
        image, prompt, nprompt, pprompt: mode-specific inputs.
        color_adj: colour calibration mode for image outputs.
        disentanglement_level: level for 'Disentanglement'.
        dual_guided_mixing: image/text mixing for 'Dual-Guided'.
        seed: RNG seed; None or negative becomes 0, floats are truncated and
            values are wrapped into the range accepted by `np.random.seed`.

    Returns:
        `(images, text)`: images for image flows, newline-joined text for text
        flows; `(None, None)` when a required input is missing or empty.

    Raises:
        ValueError: for an unknown `mode`.
    """
    if seed is None or seed < 0:
        seed = 0
    # np.random.seed only accepts integers in [0, 2**32).
    seed = int(seed) % (2 ** 32)
    np.random.seed(seed)
    torch.manual_seed(seed + 100)

    def _blank(text):
        return (text is None) or (text == "")

    if mode == 'Text-to-Image':
        if _blank(prompt):
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='image',
                cin=prompt,
                ctype='prompt', )
        return rv, None
    if mode == 'Image-Variation':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='image',
                cin=image,
                ctype='vision',
                color_adj=color_adj,)
        return rv, None
    if mode == 'Image-to-Text':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='text',
                cin=image,
                ctype='vision',)
        return None, '\n'.join(rv)
    if mode == 'Text-Variation':
        # Empty prompts are rejected, consistent with Text-to-Image.
        if _blank(prompt):
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='text',
                cin=prompt,
                ctype='prompt',)
        return None, '\n'.join(rv)
    if mode == 'Disentanglement':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_disensemble(
                cin=image,
                level=disentanglement_level,
                color_adj=color_adj,)
        return rv, None
    if mode == 'Dual-Guided':
        if (image is None) or _blank(prompt):
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_dualguided(
                cim=image,
                ctx=prompt,
                mixing=dual_guided_mixing,
                color_adj=color_adj,)
        return rv, None
    if mode == 'Latent-I2T2I':
        if (image is None) or _blank(nprompt) or _blank(pprompt):
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_i2t2i(
                cim=image,
                ctx_n=nprompt,
                ctx_p=pprompt,
                color_adj=color_adj,)
        return rv, None
    raise ValueError("No such mode: {}".format(mode))
def get_instruction(mode):
    """Return the help text shown in the "Info" box for a UI `mode`.

    Lines are joined with newlines; returns None for an unknown mode.
    """
    color_calibration = (
        "Color Calibration provides an option to adjust image color according to the reference image.")
    instructions = {
        "Text-to-Image": [
            "Generate image from text prompt."],
        "Image-Variation": [
            "Generate image conditioned on reference image.",
            color_calibration, ],
        "Image-to-Text": [
            "Generate text from reference image."],
        "Text-Variation": [
            "Generate text from reference text prompt. (Model insufficiently trained, thus results are still experimental)"],
        "Disentanglement": [
            "Generate a variation of the reference image that is disentangled for semantics or style.",
            color_calibration,
            "Disentanglement level controls the level of focus towards semantics (-2, -1) or style (1, 2). Level 0 serves as Image-Variation.", ],
        "Dual-Guided": [
            "Generate image from dual guidance of reference image and text prompt.",
            color_calibration,
            "Guidance Mixing provides linear balances between image and text context. (0 towards image, 1 towards text)", ],
        "Latent-I2T2I": [
            "Generate image variations via image-to-text, text-latent-editing, and then text-to-image. (Still under exploration)",
            color_calibration,
            "Input prompt that will be subtracted from text/text latent code.",
            "Input prompt that will be added to text/text latent code.", ],
    }
    lines = instructions.get(mode)
    return None if lines is None else '\n'.join(lines)
#############
# Interface #
#############
if True:
    # The result widgets are built before `gr.Blocks` is entered and are
    # placed later via `.render()` in the results column.
    img_output = gr.Gallery(
        label="Image Result",
    ).style(grid=n_sample_image_default)
    txt_output = gr.Textbox(
        lines=4,
        label='Text Result',
        visible=False,
    )
with gr.Blocks() as demo:
    gr.HTML(
        """
<div style="position: relative; float: left; text-align: center; width: 60%; min-width:600px; height: 160px; margin: 20px 0 20px 20%;">
<h1 style="font-weight: 900; font-size: 3rem;">
Versatile Diffusion
</h1>
<br>
<h2 style="font-weight: 450; font-size: 1rem;">
We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
VD can natively support image-to-text, image-variation, text-to-image, and text-variation,
and can be further extended to other applications such as
semantic-style disentanglement, image-text dual-guided generation, latent image-to-text-to-image editing, and more.
Future versions will support more modalities such as speech, music, video and 3D.
</h2>
<br>
<h3>Xingqian Xu, Atlas Wang, Eric Zhang, Kai Wang,
and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
[<a href="https://arxiv.org/abs/2211.08332" style="color:blue;">arXiv</a>]
[<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
</h3>
</div>
<div style="position: relative; float: right; width: 19.9%; min-width:200px; margin: 20px auto;">
<img src="https://huggingface.co/spaces/shi-labs/Versatile-Diffusion/resolve/main/assets/figures/share_instruction.png">
</div>
""")

    mode_input = gr.Radio([
        "Text-to-Image", "Image-Variation", "Image-to-Text", "Text-Variation",
        "Disentanglement", "Dual-Guided", "Latent-I2T2I"], value='Text-to-Image', label="VD Flows and Applications")
    instruction = gr.Textbox(get_instruction("Text-to-Image"), label='Info')

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image Input', visible=False)
            txt_input = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
            ntxt_input = gr.Textbox(label='Remove Prompt', visible=False)
            ptxt_input = gr.Textbox(label='Add Prompt', visible=False)
            coladj_input = gr.Radio(["None", "Simple"], value='Simple', label="Color Calibration", visible=False)
            dislvl_input = gr.Slider(-2, 2, value=0, step=1, label="Disentanglement level", visible=False)
            dguide_input = gr.Slider(0, 1, value=0.5, step=0.01, label="Guidance Mixing", visible=False)
            seed_input = gr.Number(100, label="Seed", precision=0)
            btn = gr.Button("Run")
            btn.click(
                main,
                inputs=[
                    mode_input,
                    img_input,
                    txt_input,
                    ntxt_input,
                    ptxt_input,
                    coladj_input,
                    dislvl_input,
                    dguide_input,
                    seed_input, ],
                outputs=[img_output, txt_output])
        with gr.Column():
            img_output.render()
            txt_output.render()

    example_mode = [
        "Text-to-Image",
        "Image-Variation",
        "Image-to-Text",
        "Text-Variation",
        "Disentanglement",
        "Dual-Guided",
        "Latent-I2T2I"]

    def get_example(mode):
        """Return the example rows for `mode`, each prefixed with the mode name.

        Raises:
            ValueError: for a mode without examples.
        """
        examples_by_mode = {
            'Text-to-Image': [
                ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 23],
                ['a beautiful grand nebula in the universe', 24],
                ['heavy arms gundam penguin mech', 25],
            ],
            'Image-Variation': [
                ['assets/space.jpg', 'None', 26],
                ['assets/train.jpg', 'Simple', 27],
            ],
            'Image-to-Text': [
                ['assets/boy_and_girl.jpg', 28],
                ['assets/house_by_lake.jpg', 29],
            ],
            'Text-Variation': [
                ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 32],
                ['a beautiful grand nebula in the universe', 33],
                ['heavy arms gundam penguin mech', 34],
            ],
            'Disentanglement': [
                ['assets/vermeer.jpg', 'Simple', -2, 30],
                ['assets/matisse.jpg', 'Simple', 2, 31],
            ],
            'Dual-Guided': [
                ['assets/benz.jpg', 'cyberpunk 2077', 'Simple', 0.75, 22],
                ['assets/vermeer.jpg', 'a girl with a diamond necklace', 'Simple', 0.66, 21],
            ],
            'Latent-I2T2I': [
                ['assets/ghibli.jpg', 'white house', 'tall castle', 'Simple', 20],
                ['assets/matisse.jpg', 'fruits and bottles on the table', 'flowers on the table', 'Simple', 21],
            ],
        }
        if mode not in examples_by_mode:
            raise ValueError('No examples for mode: {}'.format(mode))
        return [[mode] + casei for casei in examples_by_mode[mode]]

    def get_example_iof(mode):
        """Return `(inputs, outputs, fn)` wiring for the examples of `mode`.

        `inputs` always starts with `mode_input`, matching the first column of
        `get_example(mode)`; `fn` forwards those values to `main` and picks the
        image (index 0) or text (index 1) result.

        Raises:
            ValueError: for an unknown mode.
        """
        if mode == 'Text-to-Image':
            inps = [txt_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y: \
                main(mode=m, prompt=x, seed=y)[0]
        elif mode == "Image-Variation":
            inps = [img_input, coladj_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z: \
                main(mode=m, image=x, color_adj=y, seed=z)[0]
        elif mode == "Image-to-Text":
            inps = [img_input, seed_input]
            oups = [txt_output]
            fn = lambda m, x, y: \
                main(mode=m, image=x, seed=y)[1]
        elif mode == "Text-Variation":
            inps = [txt_input, seed_input]
            oups = [txt_output]
            fn = lambda m, x, y: \
                main(mode=m, prompt=x, seed=y)[1]
        elif mode == "Disentanglement":
            inps = [img_input, coladj_input, dislvl_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z, w: \
                main(mode=m, image=x, color_adj=y, disentanglement_level=z, seed=w)[0]
        elif mode == "Dual-Guided":
            inps = [img_input, txt_input, coladj_input, dguide_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z, w, u: \
                main(mode=m, image=x, prompt=y, color_adj=z, dual_guided_mixing=w, seed=u)[0]
        elif mode == "Latent-I2T2I":
            inps = [img_input, ntxt_input, ptxt_input, coladj_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z, w, u: \
                main(mode=m, image=x, nprompt=y, pprompt=z, color_adj=w, seed=u)[0]
        else:
            raise ValueError('No example wiring for mode: {}'.format(mode))
        return [mode_input] + inps, oups, fn

    # First row: the four basic flows; second row: the three applications.
    for row_modes in (example_mode[0:4], example_mode[4:7]):
        with gr.Row():
            for emode in row_modes:
                # Build the wiring once per mode instead of once per argument.
                example_inputs, example_outputs, example_fn = get_example_iof(emode)
                with gr.Column():
                    gr.Examples(
                        label=emode + ' Examples',
                        examples=get_example(emode),
                        inputs=example_inputs,
                        outputs=example_outputs,
                        fn=example_fn,
                        cache_examples=cache_examples)

    mode_input.change(
        fn=lambda x: gr.update(value=get_instruction(x)),
        inputs=mode_input,
        outputs=instruction,)

    def _visibility_handler(predicate):
        """Return a change handler that shows a component iff `predicate(mode)`."""
        return lambda x: gr.update(visible=predicate(x))

    # (component, predicate on the selected mode) -- registered in this order.
    visibility_rules = [
        (img_input, lambda x: x not in ['Text-to-Image', 'Text-Variation']),
        (txt_input, lambda x: x in ['Text-to-Image', 'Text-Variation', 'Dual-Guided']),
        (ntxt_input, lambda x: x in ['Latent-I2T2I']),
        (ptxt_input, lambda x: x in ['Latent-I2T2I']),
        (coladj_input, lambda x: x not in ['Text-to-Image', 'Image-to-Text', 'Text-Variation']),
        (dislvl_input, lambda x: x == 'Disentanglement'),
        (dguide_input, lambda x: x == 'Dual-Guided'),
        (img_output, lambda x: x not in ['Image-to-Text', 'Text-Variation']),
        (txt_output, lambda x: x in ['Image-to-Text', 'Text-Variation']),
    ]
    for component, predicate in visibility_rules:
        mode_input.change(
            fn=_visibility_handler(predicate),
            inputs=mode_input,
            outputs=component,)

    gr.HTML(
        """
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
<h3>
<b>Caution</b>:
We would like to raise the awareness of users of this demo of its potential issues and concerns.
Like previous large foundation models, Versatile Diffusion could be problematic in some cases, partially due to the imperfect training data and pretrained network (VAEs / context encoders) with limited scope.
In its future research phase, VD may do better on tasks such as text-to-image, image-to-text, etc., with the help of more powerful VAEs, more sophisticated network designs, and more cleaned data.
So far, we keep all features available for research testing both to show the great potential of the VD framework and to collect important feedback to improve the model in the future.
We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
</h3>
<br>
<h3>
<b>Biases and content acknowledgement</b>:
Beware that VD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
VD was trained on the LAION-2B dataset, which scraped non-curated online images and text, and may contain unintended exceptions, although we removed illegal content.
VD in this demo is meant only for research purposes.
</h3>
</div>
""")
# Start the Gradio app. The alternative previously left commented out here
# was `demo.launch(share=True)`.
demo.launch(
    debug=True,
)