Yardenfren committed
Commit ddea0a0
1 Parent(s): d1ca433

Upload 3 files

Files changed (3):
  1. app_inference.py +240 -0
  2. blora_utils.py +46 -0
  3. inf.py +121 -0
app_inference.py ADDED
@@ -0,0 +1,240 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import os
+ import random
+ from typing import Tuple, Optional
+
+ import gradio as gr
+ from huggingface_hub import HfApi
+
+ from inf import InferencePipeline
+
+ SAMPLE_MODEL_IDS = [
+     'lora-library/B-LoRA-teddybear',
+     'lora-library/B-LoRA-bull',
+     'lora-library/B-LoRA-wolf_plushie',
+     'lora-library/B-LoRA-pen_sketch',
+     'lora-library/B-LoRA-cartoon_line',
+     'lora-library/B-LoRA-multi-dog2',
+ ]
+ css = """
+ body {
+     font-size: 30px;
+ }
+ .gr-image {
+     width: 512px;
+     height: 512px;
+     object-fit: contain;
+     margin: auto;
+ }
+
+ .lora-column {
+     display: flex;
+     flex-direction: column;
+     align-items: center; /* center children horizontally (cross axis of the column) */
+     justify-content: center; /* center children vertically (main axis of the column) */
+ }
+ .gr-row {
+     align-items: center;
+     justify-content: center;
+     margin-top: 5px;
+ }
+ """
+
+
+ def get_choices(hf_token):
+     api = HfApi(token=hf_token)
+     choices = [
+         info.modelId for info in api.list_models(author='lora-library')
+     ]
+     models_list = ['None'] + SAMPLE_MODEL_IDS + choices
+     return models_list
+
+
+ def get_image_from_card(card, model_id) -> Optional[str]:
+     try:
+         card_path = f"https://huggingface.co/{model_id}/resolve/main/"
+         widget = card.data.get('widget')
+         if widget is not None and len(widget) > 0:
+             output = widget[0].get('output')
+             if output is not None:
+                 url = output.get('url')
+                 if url is not None:
+                     return card_path + url
+         return None
+     except Exception:
+         return None
+
+
+ def demo_init():
+     try:
+         choices = get_choices(app.hf_token)
+         content_blora = random.choice(SAMPLE_MODEL_IDS)
+         style_blora = random.choice(SAMPLE_MODEL_IDS)
+         content_blora_prompt, content_blora_image = app.load_model_info(content_blora)
+         style_blora_prompt, style_blora_image = app.load_model_info(style_blora)
+
+         content_lora_model_id = gr.update(choices=choices, value=content_blora)
+         content_prompt = gr.update(value=content_blora_prompt)
+         content_image = gr.update(value=content_blora_image)
+
+         style_lora_model_id = gr.update(choices=choices, value=style_blora)
+         style_prompt = gr.update(value=style_blora_prompt)
+         style_image = gr.update(value=style_blora_image)
+
+         prompt = gr.update(
+             value=f'{content_blora_prompt} in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style')
+
+         return content_lora_model_id, content_prompt, content_image, style_lora_model_id, style_prompt, style_image, prompt
+
+     except Exception as e:
+         raise type(e)(f'failed to demo_init, due to: {e}')
+
+
+ def toggle_column(is_checked):
+     try:
+         return 'None' if is_checked else random.choice(SAMPLE_MODEL_IDS)
+     except Exception as e:
+         raise type(e)(f'failed to toggle_column, due to: {e}')
+
+
+ class InferenceUtil:
+     def __init__(self, hf_token: str | None):
+         self.hf_token = hf_token
+
+     def load_model_info(self, lora_model_id: str) -> Tuple[str, Optional[str]]:
+         try:
+             try:
+                 card = InferencePipeline.get_model_card(lora_model_id,
+                                                         self.hf_token)
+             except Exception:
+                 return '', None
+             instance_prompt = getattr(card.data, 'instance_prompt', '')
+             image_url = get_image_from_card(card, lora_model_id)
+             return instance_prompt, image_url
+         except Exception as e:
+             raise type(e)(f'failed to load_model_info, due to: {e}')
+
+     def update_model_info(self, model_source: str):
+         try:
+             if model_source == 'None':
+                 return '', None
+             else:
+                 model_info = self.load_model_info(model_source)
+                 new_prompt, new_image = model_info[0], model_info[1]
+                 return new_prompt, new_image
+         except Exception as e:
+             raise type(e)(f'failed to update_model_info, due to: {e}')
+
+
+ def create_inference_demo(pipe,  #: InferencePipeline,
+                           hf_token: str | None = None) -> gr.Blocks:
+     with gr.Blocks(css=css) as demo:
+         with gr.Row(elem_classes="gr-row"):
+             with gr.Column():
+                 with gr.Group(elem_classes="lora-column"):
+                     gr.Markdown('## Content B-LoRA')
+                     content_checkbox = gr.Checkbox(label='Use Content Only', value=False)
+                     content_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
+                     content_prompt = gr.Text(label='Content instance prompt', interactive=False, max_lines=1)
+                     content_image = gr.Image(label='Content Image', elem_classes="gr-image")
+             with gr.Column():
+                 with gr.Group(elem_classes="lora-column"):
+                     gr.Markdown('## Style B-LoRA')
+                     style_checkbox = gr.Checkbox(label='Use Style Only', value=False)
+                     style_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
+                     style_prompt = gr.Text(label='Style instance prompt', interactive=False, max_lines=1)
+                     style_image = gr.Image(label='Style Image', elem_classes="gr-image")
+         with gr.Row(elem_classes="gr-row"):
+             with gr.Column():
+                 with gr.Group():
+                     prompt = gr.Textbox(
+                         label='Prompt',
+                         max_lines=1,
+                         placeholder='Example: "A [c] in [s] style"'
+                     )
+                     result = gr.Image(label='Result')
+                     with gr.Accordion('Other Parameters', open=False, elem_classes="gr-accordion"):
+                         content_alpha = gr.Slider(label='Content B-LoRA alpha',
+                                                   minimum=0,
+                                                   maximum=2,
+                                                   step=0.05,
+                                                   value=1)
+                         style_alpha = gr.Slider(label='Style B-LoRA alpha',
+                                                 minimum=0,
+                                                 maximum=2,
+                                                 step=0.05,
+                                                 value=1)
+                         seed = gr.Slider(label='Seed',
+                                          minimum=0,
+                                          maximum=100000,
+                                          step=1,
+                                          value=8888)
+                         num_steps = gr.Slider(label='Number of Steps',
+                                               minimum=0,
+                                               maximum=100,
+                                               step=1,
+                                               value=50)
+                         guidance_scale = gr.Slider(label='CFG Scale',
+                                                    minimum=0,
+                                                    maximum=50,
+                                                    step=0.1,
+                                                    value=7.5)
+
+                     run_button = gr.Button('Generate')
+         demo.load(demo_init, inputs=[],
+                   outputs=[content_lora_model_id, content_prompt, content_image, style_lora_model_id, style_prompt,
+                            style_image, prompt], queue=False, show_progress="hidden")
+         content_lora_model_id.change(
+             fn=app.update_model_info,
+             inputs=content_lora_model_id,
+             outputs=[
+                 content_prompt,
+                 content_image,
+             ])
+         style_lora_model_id.change(
+             fn=app.update_model_info,
+             inputs=style_lora_model_id,
+             outputs=[
+                 style_prompt,
+                 style_image,
+             ])
+         style_prompt.change(
+             fn=lambda content_blora_prompt, style_blora_prompt:
+                 f'{content_blora_prompt} in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style'
+                 if style_blora_prompt else content_blora_prompt,
+             inputs=[content_prompt, style_prompt],
+             outputs=prompt,
+         )
+         content_prompt.change(
+             fn=lambda content_blora_prompt, style_blora_prompt:
+                 f'{content_blora_prompt} in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style'
+                 if content_blora_prompt else style_blora_prompt,
+             inputs=[content_prompt, style_prompt],
+             outputs=prompt,
+         )
+         content_checkbox.change(toggle_column, inputs=[content_checkbox],
+                                 outputs=[style_lora_model_id])
+         style_checkbox.change(toggle_column, inputs=[style_checkbox],
+                               outputs=[content_lora_model_id])
+         inputs = [
+             content_lora_model_id,
+             style_lora_model_id,
+             prompt,
+             content_alpha,
+             style_alpha,
+             seed,
+             num_steps,
+             guidance_scale,
+         ]
+         prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
+         run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
+     return demo
+
+
+ if __name__ == '__main__':
+     hf_token = os.getenv('HF_TOKEN')
+     pipe = InferencePipeline(hf_token)
+     app = InferenceUtil(hf_token)
+     demo = create_inference_demo(pipe, hf_token)
+     demo.queue(max_size=10).launch(share=False)
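
The two .change handlers above compose the final prompt the same way: the content instance prompt, then the style instance prompt with its first character lowercased, then the word "style". A standalone sketch of that rule (the example strings are invented; real instance prompts come from the B-LoRA model cards):

def combine_prompts(content_prompt: str, style_prompt: str) -> str:
    # Mirrors the lambdas wired to content_prompt.change / style_prompt.change above.
    if not style_prompt:
        return content_prompt
    if not content_prompt:
        return style_prompt
    return f'{content_prompt} in {style_prompt[0].lower() + style_prompt[1:]} style'

assert combine_prompts('A teddybear', 'A pen sketch') == 'A teddybear in a pen sketch style'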
blora_utils.py ADDED
@@ -0,0 +1,46 @@
+ from typing import Optional
+
+ BLOCKS = {
+     'content': ['unet.up_blocks.0.attentions.0'],
+     'style': ['unet.up_blocks.0.attentions.1'],
+ }
+
+
+ def is_belong_to_blocks(key, blocks):
+     try:
+         for g in blocks:
+             if g in key:
+                 return True
+         return False
+     except Exception as e:
+         raise type(e)(f'failed to is_belong_to_blocks, due to: {e}')
+
+
+ def filter_lora(state_dict, blocks_):
+     try:
+         return {k: v for k, v in state_dict.items() if is_belong_to_blocks(k, blocks_)}
+     except Exception as e:
+         raise type(e)(f'failed to filter_lora, due to: {e}')
+
+
+ def scale_lora(state_dict, alpha):
+     try:
+         return {k: v * alpha for k, v in state_dict.items()}
+     except Exception as e:
+         raise type(e)(f'failed to scale_lora, due to: {e}')
+
+
+ def get_target_modules(unet, blocks=None):
+     try:
+         if not blocks:
+             blocks = ['.'.join(blk.split('.')[1:]) for blk in BLOCKS['content'] + BLOCKS['style']]
+
+         attns = [attn_processor_name.rsplit('.', 1)[0] for attn_processor_name, _ in unet.attn_processors.items() if
+                  is_belong_to_blocks(attn_processor_name, blocks)]
+
+         target_modules = [f'{attn}.{mat}' for mat in ["to_k", "to_q", "to_v", "to_out.0"] for attn in attns]
+         return target_modules
+     except Exception as e:
+         raise type(e)(f'failed to get_target_modules, due to: {e}')
+
+
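
A small sketch of how these helpers compose (the state-dict keys below are invented for illustration; real keys come from pipe.lora_state_dict in inf.py): entries are selected by block-name substring match, optionally rescaled, then merged into the dict that gets loaded into the UNet.

import torch

from blora_utils import BLOCKS, filter_lora, scale_lora

# Toy LoRA state dict with diffusers-style key names (illustrative only).
toy_sd = {
    'unet.up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora.down.weight': torch.ones(2, 2),
    'unet.up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.lora.down.weight': torch.ones(2, 2),
    'unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.lora.down.weight': torch.ones(2, 2),
}

content = filter_lora(toy_sd, BLOCKS['content'])               # keeps only the up_blocks.0.attentions.0 entry
style = scale_lora(filter_lora(toy_sd, BLOCKS['style']), 0.5)  # keeps up_blocks.0.attentions.1, scaled by 0.5
merged = {**content, **style}                                  # the shape of dict inf.py loads into the UNet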
inf.py ADDED
@@ -0,0 +1,121 @@
+ from __future__ import annotations
+
+ import gc
+ import pathlib
+
+ import gradio as gr
+ import PIL.Image
+ import torch
+ from diffusers import StableDiffusionXLPipeline
+ from huggingface_hub import ModelCard
+
+ from blora_utils import BLOCKS, filter_lora, scale_lora
+
+
+ class InferencePipeline:
+     def __init__(self, hf_token: str | None = None):
+         self.hf_token = hf_token
+         self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+         self.device = torch.device(
+             'cuda:0' if torch.cuda.is_available() else 'cpu')
+         if self.device.type == 'cpu':
+             self.pipe = StableDiffusionXLPipeline.from_pretrained(
+                 self.base_model_id, use_auth_token=self.hf_token, cache_dir='./cache')
+         else:
+             self.pipe = StableDiffusionXLPipeline.from_pretrained(
+                 self.base_model_id,
+                 torch_dtype=torch.float16,
+                 use_auth_token=self.hf_token)
+         self.pipe = self.pipe.to(self.device)
+         self.content_lora_model_id = None
+         self.style_lora_model_id = None
+
+     def clear(self) -> None:
+         self.content_lora_model_id = None
+         self.style_lora_model_id = None
+         del self.pipe
+         self.pipe = None
+         torch.cuda.empty_cache()
+         gc.collect()
+
+     def load_b_lora_to_unet(self, content_lora_model_id: str, style_lora_model_id: str, content_alpha: float,
+                             style_alpha: float) -> None:
+         try:
+             # Get Content B-LoRA SD
+             if content_lora_model_id:
+                 content_B_LoRA_sd, _ = self.pipe.lora_state_dict(content_lora_model_id, use_auth_token=self.hf_token)
+                 content_B_LoRA = filter_lora(content_B_LoRA_sd, BLOCKS['content'])
+                 content_B_LoRA = scale_lora(content_B_LoRA, content_alpha)
+             else:
+                 content_B_LoRA = {}
+
+             # Get Style B-LoRA SD
+             if style_lora_model_id:
+                 style_B_LoRA_sd, _ = self.pipe.lora_state_dict(style_lora_model_id, use_auth_token=self.hf_token)
+                 style_B_LoRA = filter_lora(style_B_LoRA_sd, BLOCKS['style'])
+                 style_B_LoRA = scale_lora(style_B_LoRA, style_alpha)
+             else:
+                 style_B_LoRA = {}
+
+             # Merge B-LoRAs SD
+             res_lora = {**content_B_LoRA, **style_B_LoRA}
+
+             # Load
+             self.pipe.load_lora_into_unet(res_lora, None, self.pipe.unet)
+         except Exception as e:
+             raise type(e)(f'failed to load_b_lora_to_unet, due to: {e}')
+
+     @staticmethod
+     def check_if_model_is_local(lora_model_id: str) -> bool:
+         return pathlib.Path(lora_model_id).exists()
+
+     @staticmethod
+     def get_model_card(model_id: str,
+                        hf_token: str | None = None) -> ModelCard:
+         if InferencePipeline.check_if_model_is_local(model_id):
+             card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
+         else:
+             card_path = model_id
+         return ModelCard.load(card_path, token=hf_token)
+
+     @staticmethod
+     def get_base_model_info(lora_model_id: str,
+                             hf_token: str | None = None) -> str:
+         card = InferencePipeline.get_model_card(lora_model_id, hf_token)
+         return card.data.base_model
+
+     def load_pipe(self, content_lora_model_id: str, style_lora_model_id: str, content_alpha: float,
+                   style_alpha: float) -> None:
+         if content_lora_model_id == self.content_lora_model_id and style_lora_model_id == self.style_lora_model_id:
+             return
+         self.pipe.unload_lora_weights()
+
+         self.load_b_lora_to_unet(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)
+
+         self.content_lora_model_id = content_lora_model_id
+         self.style_lora_model_id = style_lora_model_id
+
+     def run(
+             self,
+             content_lora_model_id: str,
+             style_lora_model_id: str,
+             prompt: str,
+             content_alpha: float,
+             style_alpha: float,
+             seed: int,
+             n_steps: int,
+             guidance_scale: float,
+     ) -> PIL.Image.Image:
+         if not torch.cuda.is_available():
+             raise gr.Error('CUDA is not available.')
+
+         self.load_pipe(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)
+
+         generator = torch.Generator(device=self.device).manual_seed(seed)
+         out = self.pipe(
+             prompt,
+             num_inference_steps=n_steps,
+             guidance_scale=guidance_scale,
+             generator=generator,
+         )  # type: ignore
+         return out.images[0]
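
Outside the Gradio UI, the pipeline can also be driven directly. A minimal sketch, assuming a CUDA GPU (run raises gr.Error otherwise) and two of the sample repos listed in app_inference.py; the prompt string is illustrative, since the real instance prompts come from each model card:

from inf import InferencePipeline

pipe = InferencePipeline(hf_token=None)   # a token is only needed for private or gated repos
image = pipe.run(
    'lora-library/B-LoRA-teddybear',      # content B-LoRA
    'lora-library/B-LoRA-pen_sketch',     # style B-LoRA
    'A teddybear in pen sketch style',    # illustrative prompt
    1.0,                                  # content alpha
    1.0,                                  # style alpha
    8888,                                 # seed
    50,                                   # inference steps
    7.5,                                  # CFG scale
)
image.save('teddybear_pen_sketch.png')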