import os
import shlex
import shutil
import subprocess
from pathlib import Path

import gradio as gr
import PIL.Image
import torch
from tqdm import tqdm
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad a non-square image with black borders so it becomes square."""
    w, h = image.size
    if w == h:
        return image
    elif w > h:
        # Wider than tall: pad top and bottom.
        new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
        new_image.paste(image, (0, (w - h) // 2))
        return new_image
    else:
        # Taller than wide: pad left and right.
        new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
        new_image.paste(image, ((h - w) // 2, 0))
        return new_image
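# A quick illustrative check of pad_image (sizes are arbitrary): a 640x480
# input becomes a 640x640 image with the original pasted 80 px from the top.
#   img = PIL.Image.new('RGB', (640, 480))
#   assert pad_image(img).size == (640, 640)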
def train_submit(
    prompt, anchor_prompt, concept_type, reg_lambda, iterations, lr, openai_key, save_path, mem_impath=None
):
    # NOTE: reg_lambda is accepted from the UI but not used by this launcher.
    if not torch.cuda.is_available():
        raise gr.Error('CUDA is not available.')
    torch.cuda.empty_cache()
    original_prompt = prompt
    # Defaults; the 'memorization' branch overrides these below.
    parameter_group = "cross-attn"
    train_batch_size = 4
    if concept_type == 'style':
        # Style ablation uses the bundled painting samples and anchor prompts.
        class_data_dir = './data/samples_painting/'
        anchor_prompt = './assets/painting.txt'
        openai_key = ''
    elif concept_type == 'object':
        os.makedirs('temp', exist_ok=True)
        class_data_dir = f'./temp/{anchor_prompt}'
        name = save_path.split('/')[-1]
        prompt = f'{anchor_prompt}+{prompt}'
        assert openai_key is not None
        if len(openai_key.split('\n')) > 1:
            # The text box held a list of anchor prompts (one per line) rather
            # than an OpenAI key: write them to a file and clear the key.
            with open(f'./temp/{name}.txt', 'w') as f:
                for prompt_ in openai_key.split('\n'):
                    f.write(prompt_.strip() + '\n')
            openai_key = ''
            anchor_prompt = f'./temp/{name}.txt'
    elif concept_type == 'memorization':
        # Fetch the SSCD copy-detection model used by the memorization pipeline.
        os.system("wget https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_mixup.torchscript.pt -P assets/")
        os.makedirs('temp', exist_ok=True)
        prompt = f'*+{prompt}'
        name = save_path.split('/')[-1]
        # Memorization ablation fine-tunes the full weights at a small learning rate.
        train_batch_size = 1
        lr = 5e-7
        parameter_group = "full-weight"
        assert openai_key is not None
        assert mem_impath is not None
        if len(openai_key.split('\n')) > 1:
            # As above: multiple lines means anchor prompts, not an API key.
            with open(f'./temp/{name}.txt', 'w') as f:
                for prompt_ in openai_key.split('\n'):
                    f.write(prompt_.strip() + '\n')
            openai_key = ''
            anchor_prompt = f'./temp/{name}.txt'
        else:
            anchor_prompt = prompt
        print(mem_impath)
        # Square-pad the memorized image and save it as a JPEG for training.
        image = PIL.Image.open(mem_impath[0][0].name)
        image = pad_image(image)
        image = image.convert('RGB')
        mem_impath = f"./temp/{original_prompt.lower().replace(' ', '')}.jpg"
        image.save(mem_impath, format='JPEG', quality=100)
        class_data_dir = f"./temp/{original_prompt.lower().replace(' ', '')}"
    command = f'''
    accelerate launch concept-ablation-diffusers/train.py \
        --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
        --output_dir={save_path} \
        --class_data_dir={class_data_dir} \
        --class_prompt="{anchor_prompt}" \
        --caption_target "{prompt}" \
        --concept_type {concept_type} \
        --resolution=512 \
        --train_batch_size={train_batch_size} \
        --learning_rate={lr} \
        --max_train_steps={iterations} \
        --scale_lr --hflip \
        --parameter_group {parameter_group} \
        --openai_key "{openai_key}" \
        --enable_xformers_memory_efficient_attention --num_class_images 500
    '''
    if concept_type == 'style':
        command += ' --noaug'
    if concept_type == 'memorization':
        command += f' --use_8bit_adam --with_prior_preservation --prior_loss_weight=1.0 --mem_impath {mem_impath}'
    os.makedirs(save_path, exist_ok=True)  # ensure the output dir exists before writing train.sh
    with open(f'{save_path}/train.sh', 'w') as f:
        # Save the exact launch command alongside the run for reproducibility.
        command_s = ' '.join(command.split())
        f.write(command_s)
    res = subprocess.run(shlex.split(command))
    if res.returncode == 0:
        result_message = 'Training Completed!'
    else:
        result_message = 'Training Failed!'
    weight_paths = sorted(Path(save_path).glob('*.bin'))
    print(weight_paths)
    return gr.update(value=result_message), weight_paths[0]
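# Hedged sketch of how train_submit might be wired into a Gradio UI; the
# component names below are illustrative assumptions, not the Space's actual layout:
#
#   with gr.Blocks() as demo:
#       prompt = gr.Textbox(label='Target concept')
#       concept_type = gr.Radio(['style', 'object', 'memorization'])
#       status = gr.Textbox(label='Status')
#       weights = gr.File(label='Ablated weights')
#       run = gr.Button('Train')
#       run.click(train_submit,
#                 inputs=[prompt, anchor_prompt, concept_type, reg_lambda,
#                         iterations, lr, openai_key, save_path, mem_impath],
#                 outputs=[status, weights])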
def inference(model_path, prompt, n_steps, generator):
    import sys
    sys.path.append('concept-ablation/diffusers/.')
    from model_pipeline import CustomDiffusionPipeline

    pipe = CustomDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
    ).to("cuda")
    # Sample once with the base model, then load the ablated weights and
    # sample again for a side-by-side comparison.
    image1 = pipe(prompt, num_inference_steps=n_steps, guidance_scale=6., eta=1., generator=generator).images[0]
    pipe.load_model(model_path)
    image2 = pipe(prompt, num_inference_steps=n_steps, guidance_scale=6., eta=1., generator=generator).images[0]
    return image1, image2
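# Hedged usage sketch; the checkpoint filename and seed below are illustrative
# (training saves '*.bin' weight files under save_path; the exact name may differ):
#   generator = torch.Generator(device='cuda').manual_seed(42)
#   before, after = inference('./results/vangogh/delta.bin',
#                             'painting of a house in the style of van gogh',
#                             50, generator)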