Omnibus's picture
Update app.py
c296bf2 verified
raw
history blame
11.7 kB
import gradio as gr
from models import models
from PIL import Image
import requests
import uuid
import io
import base64
import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import make_image_grid, load_image
import uuid
# Public file route of this Space; tile paths are appended to it to build URLs.
base_url = 'https://omnibus-top-20-img-img-tint.hf.space/file='

# Load every hosted model listed in `models`. The "Models" dropdown is created
# with type='index' and its value is used directly as an index into
# `loaded_model`, so this list must stay index-aligned with `models`: a model
# that fails to load gets a None placeholder instead of being skipped
# (skipping it shifted every later index onto the wrong model).
loaded_model = []
for model in models:
    try:
        loaded_model.append(gr.load(f'models/{model}'))
    except Exception as e:
        print(f'failed to load {model}: {e}')
        loaded_model.append(None)
print(loaded_model)

# Disabled upstream: a module-level img2img pipeline.
#pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None, variant="fp16", use_safetensors=True).to("cpu")
#pipeline.unet = torch.compile(pipeline.unet)

# Tiles per row: load_im cuts the upload into this many columns and the
# gallery displays the same number of columns.
grid_wide = 10
def get_concat_h_cut(in1, in2):
    """Place two PIL images side by side on a new RGB canvas.

    The canvas is as wide as both images together and as tall as the shorter
    one, so the taller image is cut off at the bottom.

    Args:
        in1: left image (an already-loaded PIL image).
        in2: right image (an already-loaded PIL image).

    Returns:
        The combined PIL image.
    """
    for item in (in1, in2):
        print(item)
    canvas_size = (in1.width + in2.width, min(in1.height, in2.height))
    canvas = Image.new('RGB', canvas_size)
    for piece, x_offset in ((in1, 0), (in2, in1.width)):
        canvas.paste(piece, (x_offset, 0))
    return canvas
def _to_image(src):
    """Return *src* as an in-memory PIL image.

    PIL images are returned unchanged. Anything else is passed to Image.open,
    copied into memory, and the underlying file handle is closed.
    """
    if isinstance(src, Image.Image):
        return src
    with Image.open(src) as opened:
        return opened.copy()


def get_concat_v_cut(in1, in2):
    """Stack two images vertically on a new RGB canvas.

    The canvas is as wide as the narrower image and as tall as both together,
    so the wider image is cut off on the right.

    Args:
        in1: top image, either a path (anything Image.open accepts) or a PIL image.
        in2: bottom image, with the same accepted types.

    Returns:
        The combined PIL image.
    """
    print(in1)
    print(in2)
    # Path inputs used to be opened and never closed. _to_image closes them
    # and also lets callers pass images that are already loaded.
    im1 = _to_image(in1)
    im2 = _to_image(in2)
    dst = Image.new('RGB', (min(im1.width, im2.width), im1.height + im2.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (0, im1.height))
    return dst
def load_model(model_drop, model_id="runwayml/stable-diffusion-v1-5"):
    """Build and return a diffusers image-to-image pipeline.

    Args:
        model_drop: not used; kept so existing callers keep working.
        model_id: Hugging Face repo id of the checkpoint. Defaults to the
            Stable Diffusion v1.5 id that was previously hard-coded.

    Returns:
        The pipeline from AutoPipelineForImage2Image.from_pretrained, loaded in
        float32 with safetensors weights. (The old body built the pipeline
        into a local variable and discarded it, so it returned None.)
    """
    return AutoPipelineForImage2Image.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        use_safetensors=True,
    )
def run_dif_color(out_prompt, im_path, model_drop, tint, im_height, im_width):
    """Generate one colour-tinted image per gallery tile and yield a growing strip.

    For each tile in the gallery (filled by load_im), this function:
    - measures the tile's mean RGB colour;
    - calls the selected hosted model with ``out_prompt`` plus a growing run
      of random spaces;
    - blends the result toward the tile colour;
    - appends it to a vertical strip saved as ``comb-<uuid>-tmp.png``.

    Args:
        out_prompt: text prompt sent to the model.
        im_path: gallery value. ``im_path.root`` holds the tiles, each
            exposing ``.image.path``. None (empty gallery) is handled.
        model_drop: index into ``loaded_model`` (the dropdown uses type='index').
        tint: blend factor in [0, 1]. 0 keeps the generated image unchanged;
            1 replaces it with a flat fill of the tile colour.
        im_height, im_width: kept for interface compatibility. The number of
            tiles now comes from the gallery itself.

    Yields:
        A tuple (path, error). path is the combined image, or None until a
        tile succeeds. error is the text of the most recent error, or "".
    """
    # Function-scope imports: the old body used cv2, numpy, random and
    # RGBTransform, none of which were ever imported.
    import random
    from PIL import ImageStat

    tiles = [] if im_path is None else im_path.root
    if not tiles:
        yield None, "No tiles to process: build the image grid first."
        return

    uid = uuid.uuid4()
    out_path = f'comb-{uid}-tmp.png'
    p_seed = ""
    out = None  # stays None until the first tile succeeds
    out_html = ""
    strip = None
    # Walk the tiles that actually exist. The old nested loops ran
    # int(im_height/grid_wide) * int(im_width/grid_wide) times, far more than
    # the number of tiles load_im creates, and ended in an IndexError.
    for tile in tiles:
        tile_file = tile.image.path
        print(f'root::{tile}')
        print(base_url + tile_file)
        # Extra spaces make every prompt distinct. This presumably avoids
        # getting a cached result back from the hosted model (TODO confirm).
        p_seed += " " * random.randint(1, 500)
        try:
            # Mean colour in RGB order. The old cv2.imread version returned
            # BGR but unpacked it as r, g, b, which swapped red and blue.
            with Image.open(tile_file) as tile_im:
                stats = ImageStat.Stat(tile_im.convert('RGB'))
            color = tuple(int(c) for c in stats.mean)
            print(color)
            model = loaded_model[int(model_drop)]
            out_img = model(out_prompt + p_seed)
            print(out_img)
            with Image.open(out_img) as generated:
                raw = generated.convert('RGB')
            # Linear mix toward the tile colour: raw * (1 - tint) + color * tint.
            tinted = Image.blend(raw, Image.new('RGB', raw.size, color), float(tint))
            if strip is None:
                strip = tinted
            else:
                # Narrowest width, summed height (same geometry as get_concat_v_cut).
                grown = Image.new(
                    'RGB',
                    (min(strip.width, tinted.width), strip.height + tinted.height),
                )
                grown.paste(strip, (0, 0))
                grown.paste(tinted, (0, strip.height))
                strip = grown
            strip.save(out_path)
            out = out_path
        except Exception as e:
            print(e)
            out_html = str(e)
        yield out, out_html
def run_dif(prompt, im_path, model_drop, cnt, strength, guidance, infer, im_height, im_width):
    """Run Stable Diffusion img2img on every gallery tile and yield a growing strip.

    Args:
        prompt: text prompt for the img2img pipeline.
        im_path: gallery value. ``im_path.root`` holds the tiles, each
            exposing ``.image.path``. None (empty gallery) is handled.
        model_drop, cnt, im_height, im_width: unused; kept so the signature
            matches existing wiring.
        strength, guidance, infer: forwarded to the pipeline as ``strength``,
            ``guidance_scale`` and ``num_inference_steps``.

    Yields:
        A tuple (path, "") after each tile. path points to
        ``comb-<uuid>-tmp.png``, which holds all results so far stacked
        top-to-bottom and cropped to the narrowest width.
    """
    tiles = [] if im_path is None else im_path.root
    if not tiles:
        yield None, "No tiles to process: build the image grid first."
        return

    # The old body called a module-level `pipeline` whose construction is
    # commented out, so it always raised NameError. Build one per call.
    pipe = AutoPipelineForImage2Image.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32, use_safetensors=True
    )
    uid = uuid.uuid4()
    out_path = f'comb-{uid}-tmp.png'
    print(f'im_path:: {im_path}')
    strip = None
    # One pass per tile. The old nested height/width loops re-ran every tile
    # int(im_height/grid_wide) * int(im_width/grid_wide) times.
    for tile in tiles:
        tile_file = tile.image.path
        print(f'root::{tile}')
        # load_image also accepts a local path, so read the tile directly
        # instead of fetching it back through base_url.
        init_image = load_image(tile_file)
        image = pipe(
            prompt,
            image=init_image,
            strength=float(strength),
            guidance_scale=float(guidance),
            num_inference_steps=int(infer),
        ).images[0]
        if strip is None:
            strip = image
        else:
            # Narrowest width, summed height (same geometry as get_concat_v_cut).
            grown = Image.new(
                'RGB',
                (min(strip.width, image.width), strip.height + image.height),
            )
            grown.paste(strip, (0, 0))
            grown.paste(image, (0, strip.height))
            strip = grown
        strip.save(out_path)
        yield out_path, ""
def run_dif_old(out_prompt, model_drop, cnt):
    """Query the selected hosted model `cnt` times, adding one space per call.

    Each call appends another space to the prompt, so consecutive prompts
    differ.

    Yields:
        After every call, a tuple (outputs, error). outputs is the list of all
        outputs collected so far; error is the text of the latest error, or "".
    """
    suffix = ""
    collected = []
    last_error = ""
    for _attempt in range(int(cnt)):
        suffix = suffix + " "
        try:
            chosen = loaded_model[int(model_drop)]
            result = chosen(out_prompt + suffix)
            print(result)
            collected.append(result)
        except Exception as err:
            print(err)
            last_error = str(err)
        yield collected, last_error
def run_dif_og(out_prompt, model_drop, cnt):
    """Call the selected hosted model `cnt` times and build an HTML image grid.

    Yields:
        After every attempt, a tuple (images, html). images is the list of PIL
        images so far. html is a ``grid_class`` div holding one captioned,
        base64-inlined ``<img>`` per successful call, followed by the escaped
        text of any errors.
    """
    from html import escape

    out_box = []
    out_html = ""
    # cnt may arrive as a float (e.g. from gr.Number); range() needs an int.
    for _ in range(int(cnt)):
        try:
            idx = int(model_drop)
            model = loaded_model[idx]
            out_img = model(out_prompt)
            print(out_img)
            # The model output is a local file path (run_dif_color opens it
            # with Image.open directly). The old code re-downloaded it from a
            # hard-coded, different Space host (omnibus-top-20), which
            # presumably does not serve this Space's files.
            with open(out_img, 'rb') as fh:
                img_bytes = fh.read()
            print(f'bytes:: {len(img_bytes)}')
            encoded = base64.b64encode(img_bytes).decode()
            img_tag = "<img src='data:image/png;base64," + encoded + "'/>"
            # Caption with the model actually called. The old code used the
            # loop counter, which labelled the i-th image with models[i].
            name = models[idx]
            out_html += (
                f"<div class='img_class'><a href='https://huggingface.co/models/{name}'>{name}</a><br>"
                + img_tag
                + "</div>"
            )
            out_box.append(Image.open(io.BytesIO(img_bytes)))
        except Exception as e:
            # Escape so error text containing '<' cannot break the markup.
            out_html += escape(str(e))
        yield out_box, "<div class='grid_class'>" + out_html + "</div>"
def thread_dif(out_prompt, mod):
    """Call hosted model number `mod` once.

    Returns:
        A tuple (out_box, html). out_box is [PIL image] on success or [] on
        failure. html is a ``grid_class`` div that is empty on success and
        holds the escaped error text otherwise.
    """
    from html import escape

    out_box = []
    out_html = ""
    try:
        # The old body started with print(ea), an undefined name, so every
        # call fell straight into the except branch.
        model = loaded_model[int(mod)]
        out_img = model(out_prompt)
        print(out_img)
        # out_img is a local file path; read it directly instead of
        # re-downloading it from a hard-coded, different Space host.
        with open(out_img, 'rb') as fh:
            out_box.append(Image.open(io.BytesIO(fh.read())))
    except Exception as e:
        out_html = escape(str(e))
    return out_box, "<div class='grid_class'>" + out_html + "</div>"
css="""
.grid_class{
display:flex;
height:100%;
}
.img_class{
min-width:200px;
}
"""
def load_im(img):
    """Cut an uploaded image into square tiles, `grid_wide` tiles per row.

    The tile side is ``width // grid_wide``. As many full rows as fit the
    height are taken; any remainder at the right or bottom edge is discarded.
    Every tile is resized to 512x512.

    Args:
        img: filesystem path of the upload (the gr.Image uses type='filepath').

    Yields:
        Once, a tuple (tiles, tiles, height, width): the tile list for the
        gallery, the same list for the hidden textbox, and the source image's
        pixel size.

    Raises:
        ValueError: if no image was supplied or it is narrower than
            `grid_wide` pixels (the tile side would be 0).
    """
    if img is None:
        raise ValueError("Upload an image before building the grid.")
    newsize = (512, 512)
    # `with` closes the source file; crop()/resize() return independent images.
    with Image.open(img) as im:
        width, height = im.size
        side = width // grid_wide
        if side == 0:
            raise ValueError(f"Image must be at least {grid_wide} pixels wide.")
        im_box = [
            im.crop((col * side, row * side, (col + 1) * side, (row + 1) * side)).resize(newsize)
            for row in range(height // side)
            for col in range(grid_wide)
        ]
    yield im_box, im_box, height, width
# Gradio UI. Flow:
# 1. Upload an image, then press "Image Grid" (load_im) to fill the gallery
#    with tiles.
# 2. The first button (no explicit label) sends the prompt and tiles to
#    run_dif_color.
# 3. Its strip image goes to `fin` and its error text to `out_html`.
with gr.Blocks(css=css) as app:
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="Prompt")
            # strength / guidance / infer (and cnt below) are shown but are not
            # among run_dif_color's inputs; they match run_dif's parameters.
            strength = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Strength")
            guidance = gr.Slider(minimum=0, maximum=10, step=0.1, value=8.0, label="Guidance")
            infer = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Inference Steps")
            tint = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.30, label="Tint Strength")
            with gr.Row():
                btn = gr.Button()
                stop_btn = gr.Button(value="Stop")
        with gr.Column():
            inp_im = gr.Image(type='filepath')
            im_btn = gr.Button(value="Image Grid")
    with gr.Row():
        model_drop = gr.Dropdown(choices=models, value=models[0], type='index', label="Models")
        cnt = gr.Number(value=1)
    out_html = gr.HTML()
    outp = gr.Gallery(columns=grid_wide)
    fin = gr.Image()
    im_height = gr.Number()
    im_width = gr.Number()
    # Hidden; receives the same tile list that load_im sends to the gallery.
    im_list = gr.Textbox(visible=False)

    im_btn.click(fn=load_im, inputs=inp_im, outputs=[outp, im_list, im_height, im_width])
    go_btn = btn.click(
        fn=run_dif_color,
        inputs=[inp, outp, model_drop, tint, im_height, im_width],
        outputs=[fin, out_html],
    )
    # The Stop button cancels the event started by the run button.
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[go_btn])

app.queue()
app.launch()