radames commited on
Commit
6a5df00
·
1 Parent(s): cefdf75

txt2imgLoraSDXL

Browse files
frontend/src/lib/components/ImagePlayer.svelte CHANGED
@@ -32,5 +32,7 @@
32
  <slot />
33
  </div>
34
  </div>
35
- <Button on:click={takeSnapshot} disabled={!isLCMRunning} classList={'ml-auto'}>Snapshot</Button>
 
 
36
  </div>
 
32
  <slot />
33
  </div>
34
  </div>
35
+ <Button on:click={takeSnapshot} disabled={!isLCMRunning} classList={'text-sm my-1 ml-auto'}
36
+ >Snap</Button
37
+ >
38
  </div>
frontend/src/lib/components/PipelineOptions.svelte CHANGED
@@ -15,31 +15,10 @@
15
  $: featuredOptions = pipelineParams?.filter((e) => e?.hide !== true);
16
  </script>
17
 
18
- <div class="grid grid-cols-1 items-center gap-3">
19
- {#if featuredOptions}
20
- {#each featuredOptions as params}
21
- {#if params.field === FieldType.RANGE}
22
- <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
23
- {:else if params.field === FieldType.SEED}
24
- <SeedInput {params} bind:value={$pipelineValues[params.id]}></SeedInput>
25
- {:else if params.field === FieldType.TEXTAREA}
26
- <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
27
- {:else if params.field === FieldType.CHECKBOX}
28
- <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
29
- {:else if params.field === FieldType.SELECT}
30
- <Selectlist {params} bind:value={$pipelineValues[params.id]}></Selectlist>
31
- {/if}
32
- {/each}
33
- {/if}
34
- </div>
35
-
36
- <details>
37
- <summary class="cursor-pointer font-medium">Advanced Options</summary>
38
- <div
39
- class="grid grid-cols-1 items-center gap-3 {pipelineParams.length > 5 ? 'sm:grid-cols-2' : ''}"
40
- >
41
- {#if advanceOptions}
42
- {#each advanceOptions as params}
43
  {#if params.field === FieldType.RANGE}
44
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
45
  {:else if params.field === FieldType.SEED}
@@ -54,4 +33,29 @@
54
  {/each}
55
  {/if}
56
  </div>
57
- </details>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  $: featuredOptions = pipelineParams?.filter((e) => e?.hide !== true);
16
  </script>
17
 
18
+ <div class="flex flex-col gap-3">
19
+ <div class="grid grid-cols-1 items-center gap-3">
20
+ {#if featuredOptions}
21
+ {#each featuredOptions as params}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  {#if params.field === FieldType.RANGE}
23
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
24
  {:else if params.field === FieldType.SEED}
 
33
  {/each}
34
  {/if}
35
  </div>
36
+
37
+ <details>
38
+ <summary class="cursor-pointer font-medium">Advanced Options</summary>
39
+ <div
40
+ class="grid grid-cols-1 items-center gap-3 {pipelineParams.length > 5
41
+ ? 'sm:grid-cols-2'
42
+ : ''}"
43
+ >
44
+ {#if advanceOptions}
45
+ {#each advanceOptions as params}
46
+ {#if params.field === FieldType.RANGE}
47
+ <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
48
+ {:else if params.field === FieldType.SEED}
49
+ <SeedInput {params} bind:value={$pipelineValues[params.id]}></SeedInput>
50
+ {:else if params.field === FieldType.TEXTAREA}
51
+ <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
52
+ {:else if params.field === FieldType.CHECKBOX}
53
+ <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
54
+ {:else if params.field === FieldType.SELECT}
55
+ <Selectlist {params} bind:value={$pipelineValues[params.id]}></Selectlist>
56
+ {/if}
57
+ {/each}
58
+ {/if}
59
+ </div>
60
+ </details>
61
+ </div>
frontend/src/lib/components/TextArea.svelte CHANGED
@@ -8,7 +8,7 @@
8
  });
9
  </script>
10
 
11
- <div class="px-1 py-1">
12
  <label class="text-sm font-medium" for={params?.title}>
13
  {params?.title}
14
  </label>
 
8
  });
9
  </script>
10
 
11
+ <div class="">
12
  <label class="text-sm font-medium" for={params?.title}>
13
  {params?.title}
14
  </label>
frontend/src/routes/+page.svelte CHANGED
@@ -111,15 +111,13 @@
111
  <article class="my-3 grid grid-cols-1 gap-3 lg:grid-cols-2">
112
  <div>
113
  <PipelineOptions {pipelineParams}></PipelineOptions>
114
- <div class="flex gap-3">
115
- <Button on:click={toggleLcmLive} {disabled}>
116
- {#if isLCMRunning}
117
- Stop
118
- {:else}
119
- Start
120
- {/if}
121
- </Button>
122
- </div>
123
  </div>
124
  <div>
125
  <ImagePlayer>
 
111
  <article class="my-3 grid grid-cols-1 gap-3 lg:grid-cols-2">
112
  <div>
113
  <PipelineOptions {pipelineParams}></PipelineOptions>
114
+ <Button on:click={toggleLcmLive} {disabled} classList={'text-lg my-1'}>
115
+ {#if isLCMRunning}
116
+ Stop
117
+ {:else}
118
+ Start
119
+ {/if}
120
+ </Button>
 
 
121
  </div>
122
  <div>
123
  <ImagePlayer>
pipelines/controlnetLoraSDXL.py CHANGED
@@ -49,13 +49,6 @@ class Pipeline:
49
  field="textarea",
50
  id="prompt",
51
  )
52
- model_id: str = Field(
53
- "plasmo/woolitize",
54
- title="Base Model",
55
- values=list(base_models.keys()),
56
- field="select",
57
- id="model_id",
58
- )
59
  negative_prompt: str = Field(
60
  default_negative_prompt,
61
  title="Negative Prompt",
@@ -70,10 +63,10 @@ class Pipeline:
70
  4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
71
  )
72
  width: int = Field(
73
- 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
74
  )
75
  height: int = Field(
76
- 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
77
  )
78
  guidance_scale: float = Field(
79
  1.0,
@@ -212,11 +205,7 @@ class Pipeline:
212
 
213
  def predict(self, params: "Pipeline.InputParams") -> Image.Image:
214
  generator = torch.manual_seed(params.seed)
215
- print(f"Using model: {params.model_id}")
216
- # pipe = self.pipes[params.model_id]
217
 
218
- # activation_token = base_models[params.model_id]
219
- # prompt = f"{activation_token} {params.prompt}"
220
  prompt_embeds, pooled_prompt_embeds = self.pipe.compel_proc(
221
  [params.prompt, params.negative_prompt]
222
  )
 
49
  field="textarea",
50
  id="prompt",
51
  )
 
 
 
 
 
 
 
52
  negative_prompt: str = Field(
53
  default_negative_prompt,
54
  title="Negative Prompt",
 
63
  4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
64
  )
65
  width: int = Field(
66
+ 768, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
67
  )
68
  height: int = Field(
69
+ 768, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
70
  )
71
  guidance_scale: float = Field(
72
  1.0,
 
205
 
206
  def predict(self, params: "Pipeline.InputParams") -> Image.Image:
207
  generator = torch.manual_seed(params.seed)
 
 
208
 
 
 
209
  prompt_embeds, pooled_prompt_embeds = self.pipe.compel_proc(
210
  [params.prompt, params.negative_prompt]
211
  )
pipelines/txt2imgLoraSDXL.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ DiffusionPipeline,
3
+ LCMScheduler,
4
+ AutoencoderKL,
5
+ )
6
+ from compel import Compel, ReturnedEmbeddingsType
7
+ import torch
8
+
9
+ try:
10
+ import intel_extension_for_pytorch as ipex # type: ignore
11
+ except:
12
+ pass
13
+
14
+ import psutil
15
+ from config import Args
16
+ from pydantic import BaseModel, Field
17
+ from PIL import Image
18
+
19
+ controlnet_model = "diffusers/controlnet-canny-sdxl-1.0"
20
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
21
+ lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
22
+
23
+
24
+ default_prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
25
+ default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
26
+
27
+
28
+ class Pipeline:
29
+ class Info(BaseModel):
30
+ name: str = "LCM+Lora+SDXL"
31
+ title: str = "Text-to-Image SDXL + LCM + LoRA"
32
+ description: str = "Generates an image from a text prompt"
33
+ input_mode: str = "text"
34
+
35
+ class InputParams(BaseModel):
36
+ prompt: str = Field(
37
+ default_prompt,
38
+ title="Prompt",
39
+ field="textarea",
40
+ id="prompt",
41
+ )
42
+ negative_prompt: str = Field(
43
+ default_negative_prompt,
44
+ title="Negative Prompt",
45
+ field="textarea",
46
+ id="negative_prompt",
47
+ hide=True,
48
+ )
49
+ seed: int = Field(
50
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
51
+ )
52
+ steps: int = Field(
53
+ 4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
54
+ )
55
+ width: int = Field(
56
+ 1024, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
57
+ )
58
+ height: int = Field(
59
+ 1024, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
60
+ )
61
+ guidance_scale: float = Field(
62
+ 1.0,
63
+ min=0,
64
+ max=20,
65
+ step=0.001,
66
+ title="Guidance Scale",
67
+ field="range",
68
+ hide=True,
69
+ id="guidance_scale",
70
+ )
71
+
72
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
73
+ vae = AutoencoderKL.from_pretrained(
74
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
75
+ )
76
+ if args.safety_checker:
77
+ self.pipe = DiffusionPipeline.from_pretrained(
78
+ model_id,
79
+ vae=vae,
80
+ )
81
+ else:
82
+ self.pipe = DiffusionPipeline.from_pretrained(
83
+ model_id,
84
+ safety_checker=None,
85
+ vae=vae,
86
+ )
87
+ # Load LCM LoRA
88
+ self.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
89
+ self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
90
+ self.pipe.set_progress_bar_config(disable=True)
91
+ self.pipe.to(device=device, dtype=torch_dtype).to(device)
92
+
93
+ if psutil.virtual_memory().total < 64 * 1024**3:
94
+ self.pipe.enable_attention_slicing()
95
+
96
+ self.pipe.compel_proc = Compel(
97
+ tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
98
+ text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
99
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
100
+ requires_pooled=[False, True],
101
+ )
102
+
103
+ if args.torch_compile:
104
+ self.pipe.unet = torch.compile(
105
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
106
+ )
107
+ self.pipe.vae = torch.compile(
108
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
109
+ )
110
+ self.pipe(
111
+ prompt="warmup",
112
+ )
113
+
114
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
115
+ generator = torch.manual_seed(params.seed)
116
+
117
+ prompt_embeds, pooled_prompt_embeds = self.pipe.compel_proc(
118
+ [params.prompt, params.negative_prompt]
119
+ )
120
+ results = self.pipe(
121
+ prompt_embeds=prompt_embeds[0:1],
122
+ pooled_prompt_embeds=pooled_prompt_embeds[0:1],
123
+ negative_prompt_embeds=prompt_embeds[1:2],
124
+ negative_pooled_prompt_embeds=pooled_prompt_embeds[1:2],
125
+ generator=generator,
126
+ num_inference_steps=params.steps,
127
+ guidance_scale=params.guidance_scale,
128
+ width=params.width,
129
+ height=params.height,
130
+ output_type="pil",
131
+ )
132
+
133
+ nsfw_content_detected = (
134
+ results.nsfw_content_detected[0]
135
+ if "nsfw_content_detected" in results
136
+ else False
137
+ )
138
+ if nsfw_content_detected:
139
+ return None
140
+ result_image = results.images[0]
141
+
142
+ return result_image