Spaces:
Runtime error
Runtime error
# https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#regional-prompting-pipeline
# https://github.com/huggingface/diffusers/blob/main/examples/community/regional_prompting_stable_diffusion.py
import gradio as gr | |
from diffusers.pipelines import pipeline_utils | |
from modules import shared, devices, scripts, processing, sd_models, prompt_parser_diffusers | |
def hijack_register_modules(self, **kwargs):
    """Replacement for ``DiffusionPipeline.register_modules`` that tolerates boolean components.

    Every ``name=module`` keyword argument is attached to ``self`` with
    ``setattr``. Non-boolean modules are additionally recorded through
    ``self.register_to_config`` as ``{name: (library, class_name)}``.

    Args:
        self: pipeline instance; must provide ``register_to_config(**kwargs)``.
        **kwargs: component name -> module. ``None``, or a tuple/list whose
            first item is ``None``, is registered as ``(None, None)``.
            ``bool`` values are set as attributes only.
    """
    for name, module in kwargs.items():
        if isinstance(module, bool):
            # A plain flag has no (library, class) pair to record. Skip the
            # config registration instead of falling through with an unset
            # (or stale, from the previous iteration) register_dict.
            setattr(self, name, module)
            continue
        if module is None or (isinstance(module, (tuple, list)) and module[0] is None):
            register_dict = {name: (None, None)}
        else:
            library, class_name = pipeline_utils._fetch_class_library_tuple(module)  # pylint: disable=protected-access
            register_dict = {name: (library, class_name)}
        self.register_to_config(**register_dict)
        setattr(self, name, module)
class Script(scripts.Script):
    """Script that runs generation through the diffusers community regional prompting pipeline."""

    def title(self):
        """Return the script name."""
        return 'Regional prompting'

    def show(self, is_img2img):
        """Return True only for non-img2img use while the diffusers backend is active."""
        return not is_img2img if shared.backend == shared.Backend.DIFFUSERS else False

    def change(self, mode):
        """Compute visibility updates for the mode-dependent inputs.

        Returns a two-item list matching ``outputs=[grid, threshold]`` in ``ui``:
        the grid input is visible for modes containing 'Col' or 'Row', the
        threshold input for modes containing 'Prompt'.
        """
        return [gr.update(visible='Col' in mode or 'Row' in mode), gr.update(visible='Prompt' in mode)]

    def ui(self, _is_img2img):
        """Build the script controls and return them in the order ``run`` receives them.

        Returns:
            ``(mode, grid, power, threshold)`` components.
        """
        with gr.Row():
            gr.HTML('<a href="https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#regional-prompting-pipeline">  Regional prompting</a><br>')
        with gr.Row():
            mode = gr.Radio(label='Mode', choices=['None', 'Prompt', 'Prompt EX', 'Columns', 'Rows'], value='None')
        with gr.Row():
            power = gr.Slider(label='Power', minimum=0, maximum=1, value=1.0, step=0.01)
            # both text inputs start hidden; `change` reveals the one relevant to the selected mode
            threshold = gr.Textbox('', label='Prompt thresholds:', default='', visible=False)
            grid = gr.Text('', label='Grid sections:', default='', visible=False)
        mode.change(fn=self.change, inputs=[mode], outputs=[grid, threshold])
        return mode, grid, power, threshold

    def run(self, p: processing.StableDiffusionProcessing, mode, grid, power, threshold): # pylint: disable=arguments-differ
        """Temporarily switch to the regional prompting pipeline and process ``p``.

        Args:
            p: processing job; its ``task_args`` are extended with ``prompt`` and ``rp_args``.
            mode: selected mode label; ``None`` or ``'None'`` disables the script.
            grid: grid sections string, passed as ``rp_args['div']`` for non-prompt modes.
            power: passed as ``rp_args['power']``.
            threshold: thresholds string, passed as ``rp_args['th']`` for prompt modes.

        Returns:
            The result of ``processing.process_images(p)``, or ``None`` when the
            script is disabled, the base model type is not ``'sd'``, or the
            pipeline switch did not produce the expected pipeline class.
        """
        if mode is None or mode == 'None':
            return None
        if shared.sd_model_type != 'sd':
            shared.log.error(f'Regional prompting: incorrect base model: {shared.sd_model.__class__.__name__}')
            return None

        # backup pipeline and params
        orig_pipeline = shared.sd_model
        orig_dtype = devices.dtype
        orig_prompt_attention = shared.opts.prompt_attention

        # NOTE(review): the register_modules hijack is never reverted (same as before) - confirm that is intended
        pipeline_utils.DiffusionPipeline.register_modules = hijack_register_modules
        prompt_parser_diffusers.EmbeddingsProvider._encode_token_ids_to_embeddings = prompt_parser_diffusers.orig_encode_token_ids_to_embeddings # pylint: disable=protected-access

        switched = False
        try:
            # create pipeline
            shared.sd_model = sd_models.switch_pipe('regional_prompting_stable_diffusion', shared.sd_model)
            if shared.sd_model.__class__.__name__ != 'RegionalPromptingStableDiffusionPipeline': # switch failed
                shared.log.error(f'Regional prompting: not a regional prompting pipeline: {shared.sd_model.__class__.__name__}')
                return None
            switched = True
            sd_models.set_diffuser_options(shared.sd_model)
            shared.opts.data['prompt_attention'] = 'Fixed attention' # this pipeline is not compatible with embeds
            processing.fix_seed(p)

            # set pipeline specific params, note that standard params are applied when applicable
            rp_args = {
                'mode': mode.lower(),
                'power': power,
            }
            if 'prompt' in mode.lower():
                rp_args['th'] = threshold
            else:
                rp_args['div'] = grid
            p.task_args = {
                **p.task_args,
                'prompt': p.prompt,
                'rp_args': rp_args,
            }

            # run pipeline
            shared.log.debug(f'Regional: args={p.task_args}')
            return processing.process_images(p) # runs processing using main loop
        finally:
            # restore pipeline and params on every exit path, including a failed switch or an exception
            prompt_parser_diffusers.EmbeddingsProvider._encode_token_ids_to_embeddings = prompt_parser_diffusers.compel_hijack # pylint: disable=protected-access
            shared.sd_model = orig_pipeline
            if switched:
                shared.opts.data['prompt_attention'] = orig_prompt_attention
                shared.sd_model.to(orig_dtype)