# https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#regional-prompting-pipeline
# https://github.com/huggingface/diffusers/blob/main/examples/community/regional_prompting_stable_diffusion.py
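# regional prompting script for the diffusers backend, built on the community RegionalPromptingStableDiffusionPipeline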
import gradio as gr
from diffusers.pipelines import pipeline_utils
from modules import shared, devices, scripts, processing, sd_models, prompt_parser_diffusers
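

# replacement for DiffusionPipeline.register_modules that also accepts boolean and None entries,
# which the regional prompting community pipeline passes and the stock diffusers implementation does not expect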
def hijack_register_modules(self, **kwargs):
    for name, module in kwargs.items():
        register_dict = None
        if module is None or (isinstance(module, (tuple, list)) and module[0] is None):
            register_dict = {name: (None, None)}
        elif isinstance(module, bool):
            pass # booleans are set as attributes but not registered to the config
        else:
            library, class_name = pipeline_utils._fetch_class_library_tuple(module) # pylint: disable=protected-access
            register_dict = {name: (library, class_name)}
        if register_dict is not None:
            self.register_to_config(**register_dict)
        setattr(self, name, module)


class Script(scripts.Script):
def title(self):
return 'Regional prompting'
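    # visible only in txt2img and only when the diffusers backend is active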
def show(self, is_img2img):
return not is_img2img if shared.backend == shared.Backend.DIFFUSERS else False
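    # toggle visibility of the grid and threshold inputs based on the selected mode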
def change(self, mode):
return [gr.update(visible='Col' in mode or 'Row' in mode), gr.update(visible='Prompt' in mode)]
def ui(self, _is_img2img):
with gr.Row():
            gr.HTML('  Regional prompting<br>')
with gr.Row():
mode = gr.Radio(label='Mode', choices=['None', 'Prompt', 'Prompt EX', 'Columns', 'Rows'], value='None')
with gr.Row():
power = gr.Slider(label='Power', minimum=0, maximum=1, value=1.0, step=0.01)
            threshold = gr.Textbox('', label='Prompt thresholds:', visible=False)
            grid = gr.Textbox('', label='Grid sections:', visible=False)
mode.change(fn=self.change, inputs=[mode], outputs=[grid, threshold])
return mode, grid, power, threshold
def run(self, p: processing.StableDiffusionProcessing, mode, grid, power, threshold): # pylint: disable=arguments-differ
if mode is None or mode == 'None':
return
# backup pipeline and params
orig_pipeline = shared.sd_model
orig_dtype = devices.dtype
orig_prompt_attention = shared.opts.prompt_attention
# create pipeline
if shared.sd_model_type != 'sd':
shared.log.error(f'Regional prompting: incorrect base model: {shared.sd_model.__class__.__name__}')
return
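        # install the tolerant register_modules and temporarily restore the original
        # token encoder, since compel embeds are not compatible with this pipeline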
pipeline_utils.DiffusionPipeline.register_modules = hijack_register_modules
prompt_parser_diffusers.EmbeddingsProvider._encode_token_ids_to_embeddings = prompt_parser_diffusers.orig_encode_token_ids_to_embeddings # pylint: disable=protected-access
shared.sd_model = sd_models.switch_pipe('regional_prompting_stable_diffusion', shared.sd_model)
if shared.sd_model.__class__.__name__ != 'RegionalPromptingStableDiffusionPipeline': # switch failed
            shared.log.error(f'Regional prompting: pipeline switch failed: {shared.sd_model.__class__.__name__}')
shared.sd_model = orig_pipeline
return
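        # re-apply standard model options to the freshly switched pipeline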
sd_models.set_diffuser_options(shared.sd_model)
shared.opts.data['prompt_attention'] = 'Fixed attention' # this pipeline is not compatible with embeds
processing.fix_seed(p)
# set pipeline specific params, note that standard params are applied when applicable
rp_args = {
'mode': mode.lower(),
'power': power,
}
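        # Prompt modes pass per-region thresholds ('th'), Columns/Rows modes pass grid divisions ('div'), as expected by the community pipeline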
if 'prompt' in mode.lower():
rp_args['th'] = threshold
else:
rp_args['div'] = grid
p.task_args = {
**p.task_args,
'prompt': p.prompt,
'rp_args': rp_args,
}
# run pipeline
        shared.log.debug(f'Regional prompting: args={p.task_args}')
processed: processing.Processed = processing.process_images(p) # runs processing using main loop
# restore pipeline and params
prompt_parser_diffusers.EmbeddingsProvider._encode_token_ids_to_embeddings = prompt_parser_diffusers.compel_hijack # pylint: disable=protected-access
shared.opts.data['prompt_attention'] = orig_prompt_attention
shared.sd_model = orig_pipeline
shared.sd_model.to(orig_dtype)
return processed