"""Gradio web demo for Diffusion Cocktail (Ditail).

Builds the UI, collects the user's settings into an argument dict and hands
it to ``DitailDemo``. The module exposes ``app`` and ``demo`` at import time.
"""
import argparse
import os

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from clip_interrogator import Config, Interrogator
from diffusers import StableDiffusionPipeline
from transformers import file_utils

from ditail import DitailDemo, seed_everything

# Display name shown in the UI -> model repo id.
BASE_MODEL = {
    'sd1.5': 'runwayml/stable-diffusion-v1-5',
    'realistic vision': 'stablediffusionapi/realistic-vision-v51',
    'pastel mix (anime)': 'stablediffusionapi/pastel-mix-stylized-anime',
    # 'chaos (abstract)': 'MAPS-research/Chaos3.0',
}

# LoRA trigger words
LORA_TRIGGER_WORD = {
    'none': [],
    'film': ['film overlay', 'film grain'],
    'snow': ['snow'],
    'flat': ['sdh', 'flat illustration'],
    'minecraft': ['minecraft square style', 'cg, computer graphics'],
    'animeoutline': ['lineart', 'monochrome'],
    'impressionism': ['impressionist', 'in the style of Monet'],
    'pop': ['POP ART'],
    'shinkai_makoto': ['shinkai makoto', 'kimi no na wa.', 'tenki no ko', 'kotonoha no niwa'],
}

# Keys of the argument dict that are echoed back to the user as JSON metadata.
METADATA_TO_SHOW = [
    'inv_model', 'spl_model', 'lora', 'lora_scale', 'inv_steps', 'spl_steps',
    'pos_prompt', 'alpha', 'neg_prompt', 'beta', 'omega',
]


class WebApp:
    """Gradio front end for the Ditail demo.

    ``args_base`` holds the default arguments; ``args_input`` maps argument
    names to the Gradio components whose values override those defaults when
    the user presses "Generate".
    """

    def __init__(self, debug_mode=False):
        """Set up defaults and, on CUDA outside debug mode, the CLIP interrogator.

        Args:
            debug_mode: When true, the CLIP interrogator is not initialised.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.args_base = {
            "seed": 42,
            "device": self.device,
            "output_dir": "output_demo",
            "caption_model_name": "blip-large",
            "clip_model_name": "ViT-L-14/openai",
            "inv_model": "stablediffusionapi/realistic-vision-v51",
            "spl_model": "runwayml/stable-diffusion-v1-5",
            "inv_steps": 50,
            "spl_steps": 50,
            "img": None,
            "pos_prompt": '',
            "neg_prompt": 'worst quality, blurry, NSFW',
            "alpha": 3.0,
            "beta": 0.5,
            "omega": 15,
            "mask": None,
            "lora": "none",
            "lora_dir": "./ditail/lora",
            "lora_scale": 0.7,
            "no_injection": False,
        }

        self.args_input = {}  # for gr.components only
        self.gr_loras = list(LORA_TRIGGER_WORD.keys())

        # Analytics tag id read from the environment; None when unset.
        self.gtag = os.environ.get('GTag')
        # Passed to gr.Blocks(head=...) in ui().
        self.ga_script = f""" """
        # JavaScript run on page load via demo.load(js=...) in ui().
        self.ga_load = f"""
        function() {{
            window.dataLayer = window.dataLayer || [];
            function gtag(){{dataLayer.push(arguments);}}
            gtag('js', new Date());

            gtag('config', '{self.gtag}');
        }}
        """

        # # pre-download base model for better user experience
        # self._preload_pipeline()

        self.debug_mode = debug_mode  # turn off clip interrogator when debugging for faster building speed
        if not self.debug_mode and self.device == "cuda":
            self.init_interrogator()

    def init_interrogator(self):
        """Create the CLIP interrogator and store it on ``self.ci``."""
        cache_path = os.environ.get('HF_HOME')
        config = Config()
        config.cache_path = cache_path
        config.clip_model_path = cache_path
        config.clip_model_name = self.args_base['clip_model_name']
        config.caption_model_name = self.args_base['caption_model_name']
        self.ci = Interrogator(config)

        # Both settings use the same value, chosen by the CLIP model name.
        size = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
        self.ci.config.chunk_size = size
        self.ci.config.flavor_intermediate_count = size

    def _preload_pipeline(self):
        """Instantiate every base model once so its weights are downloaded."""
        for model in BASE_MODEL.values():
            pipe = StableDiffusionPipeline.from_pretrained(
                model, torch_dtype=torch.float16
            ).to(self.args_base['device'])
            # Drop the reference before loading the next model.
            pipe = None

    def title(self):
        """Render the page title block."""
        gr.HTML(
            """

Diffusion Cocktail 🍸: Fused Generation from Diffusion Models

arXiv Paper Project Page GitHub Code
"""
        )

    def device_requirements(self):
        """Render the CPU-only warning and a "duplicate space" button."""
        gr.Markdown(
            """

Attention: The demo doesn't work in this space running on CPU only. \
Please duplicate and upgrade to a private "T4 medium" GPU.

"""
        )
        gr.DuplicateButton(size='lg', scale=1, variant='primary')

    def get_image(self):
        """Create the content-image input component."""
        self.args_input['img'] = gr.Image(
            label='content image',
            type='pil',
            show_share_button=False,
            elem_classes="input_image",
        )

    def get_prompts(self):
        """Create the prompt textbox and wire up automatic captioning.

        Must be called after ``get_image`` because it attaches listeners to
        ``self.args_input['img']``.
        """
        generate_prompt = gr.Checkbox(label='generate prompt with clip', value=True)
        self.args_input['pos_prompt'] = gr.Textbox(label='prompt')

        # event listeners
        self.args_input['img'].upload(
            self._interrogate_image,
            inputs=[self.args_input['img'], generate_prompt],
            outputs=[self.args_input['pos_prompt']],
        )
        generate_prompt.change(
            self._interrogate_image,
            inputs=[self.args_input['img'], generate_prompt],
            outputs=[self.args_input['pos_prompt']],
        )

    def _interrogate_image(self, image, generate_prompt):
        """Return a caption for ``image``, or ``''`` when captioning is off.

        An empty string is returned when the interrogator was never
        initialised, no image is present, or the checkbox is unticked.
        """
        if not (hasattr(self, 'ci') and image is not None and generate_prompt):
            return ''
        # Keep only the text before the first comma and drop the token 'arafed'.
        return self.ci.interrogate_fast(image).split(',')[0].replace('arafed', '')

    def get_base_model(self):
        """Create the radio selector for the target (sampling) base model."""
        model_names = list(BASE_MODEL.keys())
        self.args_input['spl_model'] = gr.Radio(
            choices=model_names,
            value=model_names[2],
            label='target base model',
        )

    def get_lora(self, num_cols=3):
        """Create the LoRA gallery; the selection is kept in a ``gr.State``.

        Args:
            num_cols: Number of gallery columns.
        """
        self.args_input['lora'] = gr.State('none')
        thumbnails = [
            (os.path.join(self.args_base['lora_dir'], f"{lora}.jpeg"), lora)
            for lora in self.gr_loras
        ]
        self.lora_gallery = gr.Gallery(
            label='target LoRA (optional)',
            columns=num_cols,
            value=thumbnails,
            allow_preview=False,
            show_share_button=False,
        )
        self.lora_gallery.select(
            self._update_lora_selection,
            inputs=[],
            outputs=[self.args_input['lora']],
        )

    def _update_lora_selection(self, selected_state: gr.SelectData):
        """Map the clicked gallery index to its LoRA name."""
        return self.gr_loras[selected_state.index]

    def get_params(self):
        """Create the advanced-option controls.

        NOTE(review): the Row/Column nesting below was reconstructed from a
        whitespace-collapsed source -- confirm against the intended layout.
        """
        model_names = list(BASE_MODEL.keys())
        with gr.Row():
            with gr.Column():
                self.args_input['inv_model'] = gr.Radio(
                    choices=model_names,
                    value=model_names[1],
                    label='inversion base model',
                )
                self.args_input['neg_prompt'] = gr.Textbox(
                    label='negative prompt',
                    value=self.args_base['neg_prompt'],
                )
                self.args_input['alpha'] = gr.Number(
                    label='positive prompt scaling weight (alpha)',
                    value=self.args_base['alpha'],
                    interactive=True,
                )
                self.args_input['beta'] = gr.Number(
                    label='negative prompt scaling weight (beta)',
                    value=self.args_base['beta'],
                    interactive=True,
                )
            with gr.Column():
                self.args_input['omega'] = gr.Slider(
                    label='cfg',
                    value=self.args_base['omega'],
                    maximum=25,
                    interactive=True,
                )
                self.args_input['inv_steps'] = gr.Slider(
                    minimum=1,
                    maximum=100,
                    label='edit steps',
                    interactive=True,
                    value=self.args_base['inv_steps'],
                    step=1,
                )
                self.args_input['spl_steps'] = gr.Slider(
                    minimum=1,
                    maximum=100,
                    label='sample steps',
                    interactive=False,
                    value=self.args_base['spl_steps'],
                    step=1,
                    visible=False,
                )
                # sync inv_steps with spl_steps
                self.args_input['inv_steps'].change(
                    lambda x: x,
                    inputs=self.args_input['inv_steps'],
                    outputs=self.args_input['spl_steps'],
                )
                # Default comes from args_base, like every other control here.
                self.args_input['lora_scale'] = gr.Slider(
                    minimum=0,
                    maximum=1,
                    label='LoRA scale',
                    value=self.args_base['lora_scale'],
                )
                self.args_input['seed'] = gr.Number(
                    label='seed',
                    value=self.args_base['seed'],
                    interactive=True,
                    precision=0,
                    step=1,
                )

    @staticmethod
    def _collect_metadata(gr_args):
        """Return the subset of ``gr_args`` listed in ``METADATA_TO_SHOW``."""
        return {key: gr_args[key] for key in METADATA_TO_SHOW}

    def run_ditail(self, *values):
        """Run Ditail with the current component values.

        Args:
            *values: Component values, in the same order as
                ``self.args_input`` (this is how ``ui`` wires the click
                handler).

        Returns:
            A ``(image, metadata)`` pair: the result of
            ``DitailDemo.run_ditail()`` and the arguments listed in
            ``METADATA_TO_SHOW``.
        """
        gr_args = self.args_base.copy()
        for k, v in zip(self.args_input, values):
            gr_args[k] = v

        # quick fix for example
        if not isinstance(gr_args['lora'], str):
            gr_args['lora'] = 'none'
        print('selected lora: ', gr_args['lora'])

        # Prepend the LoRA's trigger words to the user's prompt.
        gr_args['pos_prompt'] = ', '.join(
            LORA_TRIGGER_WORD.get(gr_args['lora'], []) + [gr_args['pos_prompt']]
        )

        # Map display names to repo ids. A value that is not a display name
        # (e.g. the repo-id defaults in args_base) is passed through unchanged
        # instead of raising KeyError.
        gr_args['inv_model'] = BASE_MODEL.get(gr_args['inv_model'], gr_args['inv_model'])
        gr_args['spl_model'] = BASE_MODEL.get(gr_args['spl_model'], gr_args['spl_model'])
        print('selected model: ', gr_args['inv_model'], gr_args['spl_model'])

        seed_everything(gr_args['seed'])
        ditail = DitailDemo(gr_args)

        args_to_show = self._collect_metadata(gr_args)
        img = ditail.run_ditail()

        # reset ditail to free memory usage
        ditail = None

        return img, args_to_show

    def run_example(self, *values):
        """Return the bundled example result without running the pipeline.

        Args:
            *values: ``img, pos_prompt, inv_model, spl_model, lora`` in that
                order.

        Returns:
            A ``(image_path, metadata)`` pair; the path points at
            ``example/Cocktail_impression.jpg`` next to this file.
        """
        gr_args = self.args_base.copy()
        for k, v in zip(['img', 'pos_prompt', 'inv_model', 'spl_model', 'lora'], values):
            gr_args[k] = v

        args_to_show = self._collect_metadata(gr_args)
        img = os.path.join(os.path.dirname(__file__), "example", "Cocktail_impression.jpg")
        return img, args_to_show

    def show_credits(self):
        """Render the model credits section."""
        gr.Markdown(
            """
### Model Credits
* Diffusion Models are downloaded from [huggingface](https://huggingface.co): [stable diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [realistic vision](https://huggingface.co/stablediffusionapi/realistic-vision-v51), [pastel mix](https://huggingface.co/stablediffusionapi/pastel-mix-stylized-anime)
* LoRA Models are downloaded from [civitai](https://civitai.com) and [liblib](https://www.liblib.art): [film](https://civitai.com/models/90393/japan-vibes-film-color), [snow](https://www.liblib.art/modelinfo/f732b23b02f041bdb7f8f3f8a256ca8b), [flat](https://www.liblib.art/modelinfo/76dcb8b59d814960b0244849f2747a15), [minecraft](https://civitai.com/models/113741/minecraft-square-style), [animeoutline](https://civitai.com/models/16014/anime-lineart-manga-like-style), [impressionism](https://civitai.com/models/113383/y5-impressionism-style), [pop](https://civitai.com/models/161450?modelVersionId=188417), [shinkai_makoto](https://civitai.com/models/10626?modelVersionId=12610)
"""
        )

    def ui(self):
        """Build and return the ``gr.Blocks`` application.

        NOTE(review): the nesting of layout contexts below was reconstructed
        from a whitespace-collapsed source -- confirm against the intended
        layout.
        """
        with gr.Blocks(css='.input_image img {object-fit: contain;}', head=self.ga_script) as demo:
            self.title()

            if self.device == "cpu":
                self.device_requirements()

            with gr.Row():
                self.get_image()

                with gr.Column():
                    self.get_prompts()
                    self.get_base_model()
                    self.get_lora(num_cols=3)
                    submit_btn = gr.Button("Generate", variant='primary')
                    if self.device == 'cpu':
                        submit_btn.variant = 'secondary'

            with gr.Accordion("advanced options", open=False):
                self.get_params()

            with gr.Row():
                with gr.Column():
                    output_image = gr.Image(label="output image")
                    metadata = gr.JSON(label='metadata')

            # Inputs are passed in args_input order; run_ditail relies on it.
            submit_btn.click(
                self.run_ditail,
                inputs=list(self.args_input.values()),
                outputs=[output_image, metadata],
                scroll_to_output=True,
            )

            with gr.Row():
                # Currently unused: the matching gr.Examples argument is
                # commented out below.
                cache_examples = not self.debug_mode
                model_names = list(BASE_MODEL.keys())
                gr.Examples(
                    examples=[[
                        os.path.join(os.path.dirname(__file__), "example", "Cocktail.jpg"),
                        'a glass of a cocktail with a lime wedge on it',
                        model_names[1],
                        model_names[1],
                        'impressionism',
                    ]],
                    inputs=[
                        self.args_input['img'],
                        self.args_input['pos_prompt'],
                        self.args_input['inv_model'],
                        self.args_input['spl_model'],
                        gr.Textbox(label='LoRA', visible=False),
                    ],
                    fn=self.run_example,
                    outputs=[output_image, metadata],
                    run_on_click=True,
                    # cache_examples=cache_examples,
                )

            self.show_credits()

            demo.load(None, js=self.ga_load)

        return demo


app = WebApp(debug_mode=False)
demo = app.ui()

if __name__ == "__main__":
    demo.launch(share=True)