radames commited on
Commit
ff9325e
1 Parent(s): 43148fd

controlnet

Browse files
app.py CHANGED
@@ -3,7 +3,7 @@ from fastapi import FastAPI
3
  from config import args
4
  from device import device, torch_dtype
5
  from app_init import init_app
6
- from user_queue import user_queue_map
7
  from util import get_pipeline_class
8
 
9
 
@@ -11,4 +11,4 @@ app = FastAPI()
11
 
12
  pipeline_class = get_pipeline_class(args.pipeline)
13
  pipeline = pipeline_class(args, device, torch_dtype)
14
- init_app(app, user_queue_map, args, pipeline)
 
3
  from config import args
4
  from device import device, torch_dtype
5
  from app_init import init_app
6
+ from user_queue import user_data_events
7
  from util import get_pipeline_class
8
 
9
 
 
11
 
12
  pipeline_class = get_pipeline_class(args.pipeline)
13
  pipeline = pipeline_class(args, device, torch_dtype)
14
+ init_app(app, user_data_events, args, pipeline)
app_init.py CHANGED
@@ -6,15 +6,16 @@ from fastapi.staticfiles import StaticFiles
6
  import logging
7
  import traceback
8
  from config import Args
9
- from user_queue import UserQueueDict
10
  import uuid
11
- import asyncio
12
  import time
13
  from PIL import Image
14
  import io
 
15
 
16
 
17
- def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
18
  app.add_middleware(
19
  CORSMiddleware,
20
  allow_origins=["*"],
@@ -27,19 +28,20 @@ def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
27
  @app.websocket("/ws")
28
  async def websocket_endpoint(websocket: WebSocket):
29
  await websocket.accept()
30
- if args.max_queue_size > 0 and len(user_queue_map) >= args.max_queue_size:
31
  print("Server is full")
32
  await websocket.send_json({"status": "error", "message": "Server is full"})
33
  await websocket.close()
34
  return
35
 
36
  try:
37
- uid = uuid.uuid4()
38
  print(f"New user connected: {uid}")
39
  await websocket.send_json(
40
  {"status": "success", "message": "Connected", "userId": uid}
41
  )
42
- user_queue_map[uid] = {"queue": asyncio.Queue()}
 
43
  await websocket.send_json(
44
  {"status": "start", "message": "Start Streaming", "userId": uid}
45
  )
@@ -49,40 +51,27 @@ def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
49
  traceback.print_exc()
50
  finally:
51
  print(f"User disconnected: {uid}")
52
- queue_value = user_queue_map.pop(uid, None)
53
- queue = queue_value.get("queue", None)
54
- if queue:
55
- while not queue.empty():
56
- try:
57
- queue.get_nowait()
58
- except asyncio.QueueEmpty:
59
- continue
60
 
61
  @app.get("/queue_size")
62
  async def get_queue_size():
63
- queue_size = len(user_queue_map)
64
  return JSONResponse({"queue_size": queue_size})
65
 
66
  @app.get("/stream/{user_id}")
67
  async def stream(user_id: uuid.UUID):
68
- uid = user_id
69
  try:
70
- user_queue = user_queue_map[uid]
71
- queue = user_queue["queue"]
72
 
73
  async def generate():
74
  last_prompt: str = None
75
  while True:
76
- data = await queue.get()
77
- input_image = data["image"]
78
  params = data["params"]
79
- if input_image is None:
80
- continue
81
-
82
- image = pipeline.predict(
83
- input_image,
84
- params,
85
- )
86
  if image is None:
87
  continue
88
  frame_data = io.BytesIO()
@@ -91,36 +80,31 @@ def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
91
  if frame_data is not None and len(frame_data) > 0:
92
  yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
93
 
94
- await asyncio.sleep(1.0 / 120.0)
95
 
96
  return StreamingResponse(
97
  generate(), media_type="multipart/x-mixed-replace;boundary=frame"
98
  )
99
  except Exception as e:
100
- logging.error(f"Streaming Error: {e}, {user_queue_map}")
101
  traceback.print_exc()
102
  return HTTPException(status_code=404, detail="User not found")
103
 
104
  async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
105
- uid = user_id
106
- user_queue = user_queue_map[uid]
107
- queue = user_queue["queue"]
108
- if not queue:
109
  return HTTPException(status_code=404, detail="User not found")
110
  last_time = time.time()
111
  try:
112
  while True:
113
- data = await websocket.receive_bytes()
114
  params = await websocket.receive_json()
115
  params = pipeline.InputParams(**params)
116
- pil_image = Image.open(io.BytesIO(data))
117
-
118
- while not queue.empty():
119
- try:
120
- queue.get_nowait()
121
- except asyncio.QueueEmpty:
122
- continue
123
- await queue.put({"image": pil_image, "params": params})
124
  if args.timeout > 0 and time.time() - last_time > args.timeout:
125
  await websocket.send_json(
126
  {
 
6
  import logging
7
  import traceback
8
  from config import Args
9
+ from user_queue import UserDataEventMap, UserDataEvent
10
  import uuid
11
+ from asyncio import Event, sleep
12
  import time
13
  from PIL import Image
14
  import io
15
+ from types import SimpleNamespace
16
 
17
 
18
+ def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipeline):
19
  app.add_middleware(
20
  CORSMiddleware,
21
  allow_origins=["*"],
 
28
  @app.websocket("/ws")
29
  async def websocket_endpoint(websocket: WebSocket):
30
  await websocket.accept()
31
+ if args.max_queue_size > 0 and len(user_data_events) >= args.max_queue_size:
32
  print("Server is full")
33
  await websocket.send_json({"status": "error", "message": "Server is full"})
34
  await websocket.close()
35
  return
36
 
37
  try:
38
+ uid = str(uuid.uuid4())
39
  print(f"New user connected: {uid}")
40
  await websocket.send_json(
41
  {"status": "success", "message": "Connected", "userId": uid}
42
  )
43
+ user_data_events[uid] = UserDataEvent()
44
+ print(f"User data events: {user_data_events}")
45
  await websocket.send_json(
46
  {"status": "start", "message": "Start Streaming", "userId": uid}
47
  )
 
51
  traceback.print_exc()
52
  finally:
53
  print(f"User disconnected: {uid}")
54
+ del user_data_events[uid]
 
 
 
 
 
 
 
55
 
56
  @app.get("/queue_size")
57
  async def get_queue_size():
58
+ queue_size = len(user_data_events)
59
  return JSONResponse({"queue_size": queue_size})
60
 
61
  @app.get("/stream/{user_id}")
62
  async def stream(user_id: uuid.UUID):
63
+ uid = str(user_id)
64
  try:
 
 
65
 
66
  async def generate():
67
  last_prompt: str = None
68
  while True:
69
+ data = await user_data_events[uid].wait_for_data()
 
70
  params = data["params"]
71
+ # input_image = data["image"]
72
+ # if input_image is None:
73
+ # continue
74
+ image = pipeline.predict(params)
 
 
 
75
  if image is None:
76
  continue
77
  frame_data = io.BytesIO()
 
80
  if frame_data is not None and len(frame_data) > 0:
81
  yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
82
 
83
+ await sleep(1.0 / 120.0)
84
 
85
  return StreamingResponse(
86
  generate(), media_type="multipart/x-mixed-replace;boundary=frame"
87
  )
88
  except Exception as e:
89
+ logging.error(f"Streaming Error: {e}, {user_data_events}")
90
  traceback.print_exc()
91
  return HTTPException(status_code=404, detail="User not found")
92
 
93
  async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
94
+ uid = str(user_id)
95
+ if uid not in user_data_events:
 
 
96
  return HTTPException(status_code=404, detail="User not found")
97
  last_time = time.time()
98
  try:
99
  while True:
 
100
  params = await websocket.receive_json()
101
  params = pipeline.InputParams(**params)
102
+ params = SimpleNamespace(**params.dict())
103
+ if hasattr(params, "image"):
104
+ image_data = await websocket.receive_bytes()
105
+ pil_image = Image.open(io.BytesIO(image_data))
106
+ params.image = pil_image
107
+ user_data_events[uid].update_data({"params": params})
 
 
108
  if args.timeout > 0 and time.time() - last_time > args.timeout:
109
  await websocket.send_json(
110
  {
frontend/package-lock.json CHANGED
@@ -7,6 +7,9 @@
7
  "": {
8
  "name": "frontend",
9
  "version": "0.0.1",
 
 
 
10
  "devDependencies": {
11
  "@sveltejs/adapter-auto": "^2.0.0",
12
  "@sveltejs/adapter-static": "^2.0.3",
@@ -3035,6 +3038,11 @@
3035
  "queue-microtask": "^1.2.2"
3036
  }
3037
  },
 
 
 
 
 
3038
  "node_modules/sade": {
3039
  "version": "1.8.1",
3040
  "resolved": "https://registry.npmjs.org/sade/-/sade-1.8.1.tgz",
 
7
  "": {
8
  "name": "frontend",
9
  "version": "0.0.1",
10
+ "dependencies": {
11
+ "rvfc-polyfill": "^1.0.7"
12
+ },
13
  "devDependencies": {
14
  "@sveltejs/adapter-auto": "^2.0.0",
15
  "@sveltejs/adapter-static": "^2.0.3",
 
3038
  "queue-microtask": "^1.2.2"
3039
  }
3040
  },
3041
+ "node_modules/rvfc-polyfill": {
3042
+ "version": "1.0.7",
3043
+ "resolved": "https://registry.npmjs.org/rvfc-polyfill/-/rvfc-polyfill-1.0.7.tgz",
3044
+ "integrity": "sha512-seBl7J1J3/k0LuzW2T9fG6JIOpni5AbU+/87LA+zTYKgTVhsfShmS8K/yOo1eeEjGJHnAdkVAUUM+PEjN9Mpkw=="
3045
+ },
3046
  "node_modules/sade": {
3047
  "version": "1.8.1",
3048
  "resolved": "https://registry.npmjs.org/sade/-/sade-1.8.1.tgz",
frontend/package.json CHANGED
@@ -33,5 +33,8 @@
33
  "typescript": "^5.0.0",
34
  "vite": "^4.4.2"
35
  },
36
- "type": "module"
 
 
 
37
  }
 
33
  "typescript": "^5.0.0",
34
  "vite": "^4.4.2"
35
  },
36
+ "type": "module",
37
+ "dependencies": {
38
+ "rvfc-polyfill": "^1.0.7"
39
+ }
40
  }
frontend/src/lib/components/Button.svelte CHANGED
@@ -1,8 +1,9 @@
1
  <script lang="ts">
2
  export let classList: string = '';
 
3
  </script>
4
 
5
- <button class="button {classList}" on:click>
6
  <slot />
7
  </button>
8
 
 
1
  <script lang="ts">
2
  export let classList: string = '';
3
+ export let disabled: boolean = false;
4
  </script>
5
 
6
+ <button class="button {classList}" on:click {disabled}>
7
  <slot />
8
  </button>
9
 
frontend/src/lib/components/Checkbox.svelte ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ import type { FieldProps } from '$lib/types';
3
+ export let value = false;
4
+ export let params: FieldProps;
5
+ </script>
6
+
7
+ <div class="grid max-w-md grid-cols-4 items-center justify-items-start gap-3">
8
+ <label class="text-sm font-medium" for={params.id}>{params?.title}</label>
9
+ <input bind:checked={value} type="checkbox" id={params.id} class="cursor-pointer" />
10
+ </div>
frontend/src/lib/components/ImagePlayer.svelte CHANGED
@@ -1,12 +1,18 @@
1
  <script lang="ts">
 
 
 
 
 
2
  </script>
3
 
4
  <div class="relative overflow-hidden rounded-lg border border-slate-300">
5
  <!-- svelte-ignore a11y-missing-attribute -->
6
- <img
7
- class="aspect-square w-full rounded-lg"
8
- src=""
9
- />
 
10
  <div class="absolute left-0 top-0 aspect-square w-1/4">
11
  <div class="relative z-10 aspect-square w-full object-cover">
12
  <slot />
 
1
  <script lang="ts">
2
+ import { isLCMRunning, lcmLiveState, lcmLiveActions } from '$lib/lcmLive';
3
+ import { onFrameChangeStore } from '$lib/mediaStream';
4
+ import { PUBLIC_BASE_URL } from '$env/static/public';
5
+
6
+ $: streamId = $lcmLiveState.streamId;
7
  </script>
8
 
9
  <div class="relative overflow-hidden rounded-lg border border-slate-300">
10
  <!-- svelte-ignore a11y-missing-attribute -->
11
+ {#if $isLCMRunning}
12
+ <img class="aspect-square w-full rounded-lg" src={PUBLIC_BASE_URL + '/stream/' + streamId} />
13
+ {:else}
14
+ <div class="aspect-square w-full rounded-lg" />
15
+ {/if}
16
  <div class="absolute left-0 top-0 aspect-square w-1/4">
17
  <div class="relative z-10 aspect-square w-full object-cover">
18
  <slot />
frontend/src/lib/components/InputRange.svelte CHANGED
@@ -8,14 +8,14 @@
8
  });
9
  </script>
10
 
11
- <div class="grid max-w-md grid-cols-4 items-center gap-3">
12
- <label class="text-sm font-medium" for="guidance-scale">{params?.title}</label>
13
  <input
14
- class="col-span-2"
15
  bind:value
16
  type="range"
17
- id="guidance-scale"
18
- name="guidance-scale"
19
  min={params?.min}
20
  max={params?.max}
21
  step={params?.step ?? 1}
@@ -24,6 +24,27 @@
24
  type="number"
25
  step={params?.step ?? 1}
26
  bind:value
27
- class="rounded-md border border-gray-700 px-1 py-1 text-center text-xs font-bold dark:text-black"
28
  />
29
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  });
9
  </script>
10
 
11
+ <div class="grid grid-cols-4 items-center gap-3">
12
+ <label class="text-sm font-medium" for={params.id}>{params?.title}</label>
13
  <input
14
+ class="col-span-2 h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-300 dark:bg-gray-500"
15
  bind:value
16
  type="range"
17
+ id={params.id}
18
+ name={params.id}
19
  min={params?.min}
20
  max={params?.max}
21
  step={params?.step ?? 1}
 
24
  type="number"
25
  step={params?.step ?? 1}
26
  bind:value
27
+ class="rounded-md border px-1 py-1 text-center text-xs font-bold dark:text-black"
28
  />
29
  </div>
30
+ <!--
31
+ <style lang="postcss" scoped>
32
+ input[type='range']::-webkit-slider-runnable-track {
33
+ @apply h-2 cursor-pointer rounded-lg dark:bg-gray-50;
34
+ }
35
+ input[type='range']::-webkit-slider-thumb {
36
+ @apply cursor-pointer rounded-lg dark:bg-gray-50;
37
+ }
38
+ input[type='range']::-moz-range-track {
39
+ @apply cursor-pointer rounded-lg dark:bg-gray-50;
40
+ }
41
+ input[type='range']::-moz-range-thumb {
42
+ @apply cursor-pointer rounded-lg dark:bg-gray-50;
43
+ }
44
+ input[type='range']::-ms-track {
45
+ @apply cursor-pointer rounded-lg dark:bg-gray-50;
46
+ }
47
+ input[type='range']::-ms-thumb {
48
+ @apply cursor-pointer rounded-lg dark:bg-gray-50;
49
+ }
50
+ </style> -->
frontend/src/lib/components/PipelineOptions.svelte CHANGED
@@ -5,6 +5,7 @@
5
  import InputRange from './InputRange.svelte';
6
  import SeedInput from './SeedInput.svelte';
7
  import TextArea from './TextArea.svelte';
 
8
 
9
  export let pipelineParams: FieldProps[];
10
  export let pipelineValues = {} as any;
@@ -17,11 +18,13 @@
17
  {#if featuredOptions}
18
  {#each featuredOptions as params}
19
  {#if params.field === FieldType.range}
20
- <InputRange {params} bind:value={pipelineValues[params.title]}></InputRange>
21
  {:else if params.field === FieldType.seed}
22
- <SeedInput bind:value={pipelineValues[params.title]}></SeedInput>
23
  {:else if params.field === FieldType.textarea}
24
- <TextArea {params} bind:value={pipelineValues[params.title]}></TextArea>
 
 
25
  {/if}
26
  {/each}
27
  {/if}
@@ -29,15 +32,17 @@
29
 
30
  <details open>
31
  <summary class="cursor-pointer font-medium">Advanced Options</summary>
32
- <div class="flex flex-col gap-3 py-3">
33
  {#if advanceOptions}
34
  {#each advanceOptions as params}
35
  {#if params.field === FieldType.range}
36
- <InputRange {params} bind:value={pipelineValues[params.title]}></InputRange>
37
  {:else if params.field === FieldType.seed}
38
- <SeedInput bind:value={pipelineValues[params.title]}></SeedInput>
39
  {:else if params.field === FieldType.textarea}
40
- <TextArea {params} bind:value={pipelineValues[params.title]}></TextArea>
 
 
41
  {/if}
42
  {/each}
43
  {/if}
 
5
  import InputRange from './InputRange.svelte';
6
  import SeedInput from './SeedInput.svelte';
7
  import TextArea from './TextArea.svelte';
8
+ import Checkbox from './Checkbox.svelte';
9
 
10
  export let pipelineParams: FieldProps[];
11
  export let pipelineValues = {} as any;
 
18
  {#if featuredOptions}
19
  {#each featuredOptions as params}
20
  {#if params.field === FieldType.range}
21
+ <InputRange {params} bind:value={pipelineValues[params.id]}></InputRange>
22
  {:else if params.field === FieldType.seed}
23
+ <SeedInput bind:value={pipelineValues[params.id]}></SeedInput>
24
  {:else if params.field === FieldType.textarea}
25
+ <TextArea {params} bind:value={pipelineValues[params.id]}></TextArea>
26
+ {:else if params.field === FieldType.checkbox}
27
+ <Checkbox {params} bind:value={pipelineValues[params.id]}></Checkbox>
28
  {/if}
29
  {/each}
30
  {/if}
 
32
 
33
  <details open>
34
  <summary class="cursor-pointer font-medium">Advanced Options</summary>
35
+ <div class="grid grid-cols-1 items-center gap-3 sm:grid-cols-2">
36
  {#if advanceOptions}
37
  {#each advanceOptions as params}
38
  {#if params.field === FieldType.range}
39
+ <InputRange {params} bind:value={pipelineValues[params.id]}></InputRange>
40
  {:else if params.field === FieldType.seed}
41
+ <SeedInput bind:value={pipelineValues[params.id]}></SeedInput>
42
  {:else if params.field === FieldType.textarea}
43
+ <TextArea {params} bind:value={pipelineValues[params.id]}></TextArea>
44
+ {:else if params.field === FieldType.checkbox}
45
+ <Checkbox {params} bind:value={pipelineValues[params.id]}></Checkbox>
46
  {/if}
47
  {/each}
48
  {/if}
frontend/src/lib/components/SeedInput.svelte CHANGED
@@ -16,5 +16,5 @@
16
  name="seed"
17
  class="col-span-2 rounded-md border border-gray-700 p-2 text-right font-light dark:text-black"
18
  />
19
- <Button on:click={randomize}>Randomize</Button>
20
  </div>
 
16
  name="seed"
17
  class="col-span-2 rounded-md border border-gray-700 p-2 text-right font-light dark:text-black"
18
  />
19
+ <Button on:click={randomize}>Rand</Button>
20
  </div>
frontend/src/lib/components/VideoInput.svelte CHANGED
@@ -1,4 +1,73 @@
1
  <script lang="ts">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  </script>
3
 
4
- <video playsinline autoplay muted loop />
 
 
 
 
 
 
 
 
 
1
  <script lang="ts">
2
+ import 'rvfc-polyfill';
3
+ import { onMount, onDestroy } from 'svelte';
4
+ import {
5
+ mediaStreamState,
6
+ mediaStreamActions,
7
+ isMediaStreaming,
8
+ MediaStreamStatus,
9
+ onFrameChangeStore
10
+ } from '$lib/mediaStream';
11
+
12
+ $: mediaStream = $mediaStreamState.mediaStream;
13
+
14
+ let videoEl: HTMLVideoElement;
15
+ let videoFrameCallbackId: number;
16
+ const WIDTH = 512;
17
+ const HEIGHT = 512;
18
+
19
+ onDestroy(() => {
20
+ if (videoFrameCallbackId) videoEl.cancelVideoFrameCallback(videoFrameCallbackId);
21
+ });
22
+
23
+ function srcObject(node: HTMLVideoElement, stream: MediaStream) {
24
+ node.srcObject = stream;
25
+ return {
26
+ update(newStream: MediaStream) {
27
+ if (node.srcObject != newStream) {
28
+ node.srcObject = newStream;
29
+ }
30
+ }
31
+ };
32
+ }
33
+ async function onFrameChange(now: DOMHighResTimeStamp, metadata: VideoFrameCallbackMetadata) {
34
+ const blob = await grapBlobImg();
35
+ onFrameChangeStore.set({ now, metadata, blob });
36
+ videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
37
+ }
38
+
39
+ $: if ($isMediaStreaming == MediaStreamStatus.CONNECTED) {
40
+ videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
41
+ }
42
+ async function grapBlobImg() {
43
+ const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
44
+ const videoW = videoEl.videoWidth;
45
+ const videoH = videoEl.videoHeight;
46
+ const aspectRatio = WIDTH / HEIGHT;
47
+
48
+ const ctx = canvas.getContext('2d') as OffscreenCanvasRenderingContext2D;
49
+ ctx.drawImage(
50
+ videoEl,
51
+ videoW / 2 - (videoH * aspectRatio) / 2,
52
+ 0,
53
+ videoH * aspectRatio,
54
+ videoH,
55
+ 0,
56
+ 0,
57
+ WIDTH,
58
+ HEIGHT
59
+ );
60
+ const blob = await canvas.convertToBlob({ type: 'image/jpeg', quality: 1 });
61
+ return blob;
62
+ }
63
  </script>
64
 
65
+ <video
66
+ class="aspect-square w-full object-cover"
67
+ bind:this={videoEl}
68
+ playsinline
69
+ autoplay
70
+ muted
71
+ loop
72
+ use:srcObject={mediaStream}
73
+ ></video>
frontend/src/lib/lcmLive.ts ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { writable } from 'svelte/store';
2
+ import { PUBLIC_BASE_URL, PUBLIC_WSS_URL } from '$env/static/public';
3
+
4
+ export const isStreaming = writable(false);
5
+ export const isLCMRunning = writable(false);
6
+
7
+
8
+ export enum LCMLiveStatus {
9
+ INIT = "init",
10
+ CONNECTED = "connected",
11
+ DISCONNECTED = "disconnected",
12
+ }
13
+
14
+ interface lcmLive {
15
+ streamId: string | null;
16
+ status: LCMLiveStatus
17
+ }
18
+
19
+ const initialState: lcmLive = {
20
+ streamId: null,
21
+ status: LCMLiveStatus.INIT
22
+ };
23
+
24
+ export const lcmLiveState = writable(initialState);
25
+
26
+ let websocket: WebSocket | null = null;
27
+ export const lcmLiveActions = {
28
+ async start() {
29
+
30
+ isLCMRunning.set(true);
31
+ try {
32
+ const websocketURL = PUBLIC_WSS_URL ? PUBLIC_WSS_URL : `${window.location.protocol === "https:" ? "wss" : "ws"
33
+ }:${window.location.host}/ws`;
34
+
35
+ websocket = new WebSocket(websocketURL);
36
+ websocket.onopen = () => {
37
+ console.log("Connected to websocket");
38
+ };
39
+ websocket.onclose = () => {
40
+ lcmLiveState.update((state) => ({
41
+ ...state,
42
+ status: LCMLiveStatus.DISCONNECTED
43
+ }));
44
+ console.log("Disconnected from websocket");
45
+ isLCMRunning.set(false);
46
+ };
47
+ websocket.onerror = (err) => {
48
+ console.error(err);
49
+ };
50
+ websocket.onmessage = (event) => {
51
+ const data = JSON.parse(event.data);
52
+ console.log("WS: ", data);
53
+ switch (data.status) {
54
+ case "success":
55
+ break;
56
+ case "start":
57
+ const streamId = data.userId;
58
+ lcmLiveState.update((state) => ({
59
+ ...state,
60
+ status: LCMLiveStatus.CONNECTED,
61
+ streamId: streamId,
62
+ }));
63
+ break;
64
+ case "timeout":
65
+ console.log("timeout");
66
+ case "error":
67
+ console.log(data.message);
68
+ isLCMRunning.set(false);
69
+ }
70
+ };
71
+ lcmLiveState.update((state) => ({
72
+ ...state,
73
+ }));
74
+ } catch (err) {
75
+ console.error(err);
76
+ isLCMRunning.set(false);
77
+ }
78
+ },
79
+ send(data: Blob | { [key: string]: any }) {
80
+ if (websocket && websocket.readyState === WebSocket.OPEN) {
81
+ if (data instanceof Blob) {
82
+ websocket.send(data);
83
+ } else {
84
+ websocket.send(JSON.stringify(data));
85
+ }
86
+ } else {
87
+ console.log("WebSocket not connected");
88
+ }
89
+ },
90
+ async stop() {
91
+
92
+ if (websocket) {
93
+ websocket.close();
94
+ }
95
+ websocket = null;
96
+ lcmLiveState.set({ status: LCMLiveStatus.DISCONNECTED, streamId: null });
97
+ isLCMRunning.set(false)
98
+ },
99
+ };
frontend/src/lib/mediaStream.ts ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { writable, type Writable } from 'svelte/store';
2
+
3
+ export enum MediaStreamStatus {
4
+ INIT = "init",
5
+ CONNECTED = "connected",
6
+ DISCONNECTED = "disconnected",
7
+ }
8
+ export const onFrameChangeStore: Writable<{ now: Number, metadata: VideoFrameCallbackMetadata, blob: Blob }> = writable();
9
+ export const isMediaStreaming = writable(MediaStreamStatus.INIT);
10
+
11
+ interface mediaStream {
12
+ mediaStream: MediaStream | null;
13
+ status: MediaStreamStatus
14
+ devices: MediaDeviceInfo[];
15
+ }
16
+
17
+ const initialState: mediaStream = {
18
+ mediaStream: null,
19
+ status: MediaStreamStatus.INIT,
20
+ devices: [],
21
+ };
22
+
23
+ export const mediaStreamState = writable(initialState);
24
+
25
+ export const mediaStreamActions = {
26
+ async enumerateDevices() {
27
+ console.log("Enumerating devices");
28
+ await navigator.mediaDevices.enumerateDevices()
29
+ .then(devices => {
30
+ const cameras = devices.filter(device => device.kind === 'videoinput');
31
+ console.log("Cameras: ", cameras);
32
+ mediaStreamState.update((state) => ({
33
+ ...state,
34
+ devices: cameras,
35
+ }));
36
+ })
37
+ .catch(err => {
38
+ console.error(err);
39
+ });
40
+ },
41
+ async start(mediaDevicedID?: string) {
42
+ const constraints = {
43
+ audio: false,
44
+ video: {
45
+ width: 1024, height: 1024, deviceId: mediaDevicedID
46
+ }
47
+ };
48
+
49
+ await navigator.mediaDevices
50
+ .getUserMedia(constraints)
51
+ .then((mediaStream) => {
52
+ mediaStreamState.update((state) => ({
53
+ ...state,
54
+ mediaStream: mediaStream,
55
+ status: MediaStreamStatus.CONNECTED,
56
+ }));
57
+ isMediaStreaming.set(MediaStreamStatus.CONNECTED);
58
+ })
59
+ .catch((err) => {
60
+ console.error(`${err.name}: ${err.message}`);
61
+ isMediaStreaming.set(MediaStreamStatus.DISCONNECTED);
62
+ });
63
+ },
64
+ async switchCamera(mediaDevicedID: string) {
65
+ const constraints = {
66
+ audio: false,
67
+ video: { width: 1024, height: 1024, deviceId: mediaDevicedID }
68
+ };
69
+ await navigator.mediaDevices
70
+ .getUserMedia(constraints)
71
+ .then((mediaStream) => {
72
+ mediaStreamState.update((state) => ({
73
+ ...state,
74
+ mediaStream: mediaStream,
75
+ status: MediaStreamStatus.CONNECTED,
76
+ }));
77
+ })
78
+ .catch((err) => {
79
+ console.error(`${err.name}: ${err.message}`);
80
+ });
81
+ },
82
+ async stop() {
83
+ navigator.mediaDevices.getUserMedia({ video: true }).then((mediaStream) => {
84
+ mediaStream.getTracks().forEach((track) => track.stop());
85
+ });
86
+ mediaStreamState.update((state) => ({
87
+ ...state,
88
+ mediaStream: null,
89
+ status: MediaStreamStatus.DISCONNECTED,
90
+ }));
91
+ isMediaStreaming.set(MediaStreamStatus.DISCONNECTED);
92
+ },
93
+ };
frontend/src/lib/types.ts CHANGED
@@ -2,6 +2,7 @@ export const enum FieldType {
2
  range = "range",
3
  seed = "seed",
4
  textarea = "textarea",
 
5
  }
6
 
7
  export interface FieldProps {
@@ -13,6 +14,7 @@ export interface FieldProps {
13
  step?: number;
14
  disabled?: boolean;
15
  hide?: boolean;
 
16
  }
17
  export interface PipelineInfo {
18
  name: string;
 
2
  range = "range",
3
  seed = "seed",
4
  textarea = "textarea",
5
+ checkbox = "checkbox",
6
  }
7
 
8
  export interface FieldProps {
 
14
  step?: number;
15
  disabled?: boolean;
16
  hide?: boolean;
17
+ id: string;
18
  }
19
  export interface PipelineInfo {
20
  name: string;
frontend/src/lib/utils.ts ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export function LCMLive(webcamVideo, liveImage) {
2
+ let websocket: WebSocket;
3
+
4
+ async function start() {
5
+ return new Promise((resolve, reject) => {
6
+ const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
7
+ }:${window.location.host}/ws`;
8
+
9
+ const socket = new WebSocket(websocketURL);
10
+ socket.onopen = () => {
11
+ console.log("Connected to websocket");
12
+ };
13
+ socket.onclose = () => {
14
+ console.log("Disconnected from websocket");
15
+ stop();
16
+ resolve({ "status": "disconnected" });
17
+ };
18
+ socket.onerror = (err) => {
19
+ console.error(err);
20
+ reject(err);
21
+ };
22
+ socket.onmessage = (event) => {
23
+ const data = JSON.parse(event.data);
24
+ switch (data.status) {
25
+ case "success":
26
+ break;
27
+ case "start":
28
+ const userId = data.userId;
29
+ initVideoStream(userId);
30
+ break;
31
+ case "timeout":
32
+ stop();
33
+ resolve({ "status": "timeout" });
34
+ case "error":
35
+ stop();
36
+ reject(data.message);
37
+
38
+ }
39
+ };
40
+ websocket = socket;
41
+ })
42
+ }
43
+ function switchCamera() {
44
+ const constraints = {
45
+ audio: false,
46
+ video: { width: 1024, height: 1024, deviceId: mediaDevices[webcamsEl.value].deviceId }
47
+ };
48
+ navigator.mediaDevices
49
+ .getUserMedia(constraints)
50
+ .then((mediaStream) => {
51
+ webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
52
+ webcamVideo.srcObject = mediaStream;
53
+ webcamVideo.onloadedmetadata = () => {
54
+ webcamVideo.play();
55
+ webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
56
+ };
57
+ })
58
+ .catch((err) => {
59
+ console.error(`${err.name}: ${err.message}`);
60
+ });
61
+ }
62
+
63
+ async function videoTimeUpdateHandler() {
64
+ const dimension = getValue("input[name=dimension]:checked");
65
+ const [WIDTH, HEIGHT] = JSON.parse(dimension);
66
+
67
+ const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
68
+ const videoW = webcamVideo.videoWidth;
69
+ const videoH = webcamVideo.videoHeight;
70
+ const aspectRatio = WIDTH / HEIGHT;
71
+
72
+ const ctx = canvas.getContext("2d");
73
+ ctx.drawImage(webcamVideo, videoW / 2 - videoH * aspectRatio / 2, 0, videoH * aspectRatio, videoH, 0, 0, WIDTH, HEIGHT)
74
+ const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
75
+ websocket.send(blob);
76
+ websocket.send(JSON.stringify({
77
+ "seed": getValue("#seed"),
78
+ "prompt": getValue("#prompt"),
79
+ "guidance_scale": getValue("#guidance-scale"),
80
+ "strength": getValue("#strength"),
81
+ "steps": getValue("#steps"),
82
+ "lcm_steps": getValue("#lcm_steps"),
83
+ "width": WIDTH,
84
+ "height": HEIGHT,
85
+ "controlnet_scale": getValue("#controlnet_scale"),
86
+ "controlnet_start": getValue("#controlnet_start"),
87
+ "controlnet_end": getValue("#controlnet_end"),
88
+ "canny_low_threshold": getValue("#canny_low_threshold"),
89
+ "canny_high_threshold": getValue("#canny_high_threshold"),
90
+ "debug_canny": getValue("#debug_canny")
91
+ }));
92
+ }
93
+ let mediaDevices = [];
94
+ async function initVideoStream(userId) {
95
+ liveImage.src = `/stream/${userId}`;
96
+ await navigator.mediaDevices.enumerateDevices()
97
+ .then(devices => {
98
+ const cameras = devices.filter(device => device.kind === 'videoinput');
99
+ mediaDevices = cameras;
100
+ webcamsEl.innerHTML = "";
101
+ cameras.forEach((camera, index) => {
102
+ const option = document.createElement("option");
103
+ option.value = index;
104
+ option.innerText = camera.label;
105
+ webcamsEl.appendChild(option);
106
+ option.selected = index === 0;
107
+ });
108
+ webcamsEl.addEventListener("change", switchCamera);
109
+ })
110
+ .catch(err => {
111
+ console.error(err);
112
+ });
113
+ const constraints = {
114
+ audio: false,
115
+ video: { width: 1024, height: 1024, deviceId: mediaDevices[0].deviceId }
116
+ };
117
+ navigator.mediaDevices
118
+ .getUserMedia(constraints)
119
+ .then((mediaStream) => {
120
+ webcamVideo.srcObject = mediaStream;
121
+ webcamVideo.onloadedmetadata = () => {
122
+ webcamVideo.play();
123
+ webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
124
+ };
125
+ })
126
+ .catch((err) => {
127
+ console.error(`${err.name}: ${err.message}`);
128
+ });
129
+ }
130
+
131
+
132
+ async function stop() {
133
+ websocket.close();
134
+ navigator.mediaDevices.getUserMedia({ video: true }).then((mediaStream) => {
135
+ mediaStream.getTracks().forEach((track) => track.stop());
136
+ });
137
+ webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
138
+ webcamsEl.removeEventListener("change", switchCamera);
139
+ webcamVideo.srcObject = null;
140
+ }
141
+ return {
142
+ start,
143
+ stop
144
+ }
145
+ }
frontend/src/routes/+page.svelte CHANGED
@@ -7,6 +7,13 @@
7
  import Button from '$lib/components/Button.svelte';
8
  import PipelineOptions from '$lib/components/PipelineOptions.svelte';
9
  import Spinner from '$lib/icons/spinner.svelte';
 
 
 
 
 
 
 
10
 
11
  let pipelineParams: FieldProps[];
12
  let pipelineInfo: PipelineInfo;
@@ -21,11 +28,58 @@
21
  pipelineParams = Object.values(settings.input_params.properties);
22
  pipelineInfo = settings.info.properties;
23
  pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
 
24
  console.log('SETTINGS', pipelineInfo);
25
  }
26
 
27
- $: {
28
- console.log('PARENT', pipelineValues);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  }
30
  </script>
31
 
@@ -58,19 +112,26 @@
58
  </p>
59
  </article>
60
  {#if pipelineParams}
61
- <h2 class="font-medium">Prompt</h2>
62
- <p class="text-sm text-gray-500">
63
- Change the prompt to generate different images, accepts <a
64
- href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
65
- target="_blank"
66
- class="text-blue-500 underline hover:no-underline">Compel</a
67
- > syntax.
68
- </p>
 
 
69
  <PipelineOptions {pipelineParams} bind:pipelineValues></PipelineOptions>
70
  <div class="flex gap-3">
71
- <Button>Start</Button>
72
- <Button>Stop</Button>
73
- <Button classList={'ml-auto'}>Snapshot</Button>
 
 
 
 
 
74
  </div>
75
 
76
  <ImagePlayer>
 
7
  import Button from '$lib/components/Button.svelte';
8
  import PipelineOptions from '$lib/components/PipelineOptions.svelte';
9
  import Spinner from '$lib/icons/spinner.svelte';
10
+ import { isLCMRunning, lcmLiveState, lcmLiveActions, LCMLiveStatus } from '$lib/lcmLive';
11
+ import {
12
+ mediaStreamState,
13
+ mediaStreamActions,
14
+ isMediaStreaming,
15
+ onFrameChangeStore
16
+ } from '$lib/mediaStream';
17
 
18
  let pipelineParams: FieldProps[];
19
  let pipelineInfo: PipelineInfo;
 
28
  pipelineParams = Object.values(settings.input_params.properties);
29
  pipelineInfo = settings.info.properties;
30
  pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
31
+ console.log('PARAMS', pipelineParams);
32
  console.log('SETTINGS', pipelineInfo);
33
  }
34
 
35
+ // $: {
36
+ // console.log('isLCMRunning', $isLCMRunning);
37
+ // }
38
+ // $: {
39
+ // console.log('lcmLiveState', $lcmLiveState);
40
+ // }
41
+ // $: {
42
+ // console.log('mediaStreamState', $mediaStreamState);
43
+ // }
44
+ // $: if ($lcmLiveState.status === LCMLiveStatus.CONNECTED) {
45
+ // lcmLiveActions.send(pipelineValues);
46
+ // }
47
+ onFrameChangeStore.subscribe(async (frame) => {
48
+ if ($lcmLiveState.status === LCMLiveStatus.CONNECTED) {
49
+ lcmLiveActions.send(pipelineValues);
50
+ lcmLiveActions.send(frame.blob);
51
+ }
52
+ });
53
+ let startBt: Button;
54
+ let stopBt: Button;
55
+ let snapShotBt: Button;
56
+
57
+ async function toggleLcmLive() {
58
+ if (!$isLCMRunning) {
59
+ await mediaStreamActions.enumerateDevices();
60
+ await mediaStreamActions.start();
61
+ lcmLiveActions.start();
62
+ } else {
63
+ mediaStreamActions.stop();
64
+ lcmLiveActions.stop();
65
+ }
66
+ }
67
+ async function startLcmLive() {
68
+ try {
69
+ $isLCMRunning = true;
70
+ // const res = await lcmLive.start();
71
+ $isLCMRunning = false;
72
+ // if (res.status === "timeout")
73
+ // toggleMessage("success")
74
+ } catch (err) {
75
+ console.log(err);
76
+ // toggleMessage("error")
77
+ $isLCMRunning = false;
78
+ }
79
+ }
80
+ async function stopLcmLive() {
81
+ // await lcmLive.stop();
82
+ $isLCMRunning = false;
83
  }
84
  </script>
85
 
 
112
  </p>
113
  </article>
114
  {#if pipelineParams}
115
+ <header>
116
+ <h2 class="font-medium">Prompt</h2>
117
+ <p class="text-sm text-gray-500">
118
+ Change the prompt to generate different images, accepts <a
119
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
120
+ target="_blank"
121
+ class="text-blue-500 underline hover:no-underline">Compel</a
122
+ > syntax.
123
+ </p>
124
+ </header>
125
  <PipelineOptions {pipelineParams} bind:pipelineValues></PipelineOptions>
126
  <div class="flex gap-3">
127
+ <Button on:click={toggleLcmLive}>
128
+ {#if $isLCMRunning}
129
+ Stop
130
+ {:else}
131
+ Start
132
+ {/if}
133
+ </Button>
134
+ <Button disabled={$isLCMRunning} classList={'ml-auto'}>Snapshot</Button>
135
  </div>
136
 
137
  <ImagePlayer>
latent_consistency_controlnet.py DELETED
@@ -1,1100 +0,0 @@
1
- # from https://github.com/taabata/LCM_Inpaint_Outpaint_Comfy/blob/main/LCM/pipeline_cn.py
2
- # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
17
- # and https://github.com/hojonathanho/diffusion
18
-
19
- import math
20
- from dataclasses import dataclass
21
- from typing import Any, Dict, List, Optional, Tuple, Union
22
-
23
- import numpy as np
24
- import torch
25
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
26
-
27
- from diffusers import (
28
- AutoencoderKL,
29
- ConfigMixin,
30
- DiffusionPipeline,
31
- SchedulerMixin,
32
- UNet2DConditionModel,
33
- ControlNetModel,
34
- logging,
35
- )
36
- from diffusers.configuration_utils import register_to_config
37
- from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
38
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
39
- from diffusers.pipelines.stable_diffusion.safety_checker import (
40
- StableDiffusionSafetyChecker,
41
- )
42
- from diffusers.utils import BaseOutput
43
-
44
- from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
45
- from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
46
-
47
-
48
- import PIL.Image
49
-
50
-
51
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
-
53
-
54
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
55
- def retrieve_latents(encoder_output, generator):
56
- if hasattr(encoder_output, "latent_dist"):
57
- return encoder_output.latent_dist.sample(generator)
58
- elif hasattr(encoder_output, "latents"):
59
- return encoder_output.latents
60
- else:
61
- raise AttributeError("Could not access latents of provided encoder_output")
62
-
63
-
64
- class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
65
- _optional_components = ["scheduler"]
66
-
67
- def __init__(
68
- self,
69
- vae: AutoencoderKL,
70
- text_encoder: CLIPTextModel,
71
- tokenizer: CLIPTokenizer,
72
- controlnet: Union[
73
- ControlNetModel,
74
- List[ControlNetModel],
75
- Tuple[ControlNetModel],
76
- MultiControlNetModel,
77
- ],
78
- unet: UNet2DConditionModel,
79
- scheduler: "LCMScheduler",
80
- safety_checker: StableDiffusionSafetyChecker,
81
- feature_extractor: CLIPImageProcessor,
82
- requires_safety_checker: bool = True,
83
- ):
84
- super().__init__()
85
-
86
- scheduler = (
87
- scheduler
88
- if scheduler is not None
89
- else LCMScheduler_X(
90
- beta_start=0.00085,
91
- beta_end=0.0120,
92
- beta_schedule="scaled_linear",
93
- prediction_type="epsilon",
94
- )
95
- )
96
-
97
- self.register_modules(
98
- vae=vae,
99
- text_encoder=text_encoder,
100
- tokenizer=tokenizer,
101
- unet=unet,
102
- controlnet=controlnet,
103
- scheduler=scheduler,
104
- safety_checker=safety_checker,
105
- feature_extractor=feature_extractor,
106
- )
107
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
108
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
109
- self.control_image_processor = VaeImageProcessor(
110
- vae_scale_factor=self.vae_scale_factor,
111
- do_convert_rgb=True,
112
- do_normalize=False,
113
- )
114
-
115
- def _encode_prompt(
116
- self,
117
- prompt,
118
- device,
119
- num_images_per_prompt,
120
- prompt_embeds: None,
121
- ):
122
- r"""
123
- Encodes the prompt into text encoder hidden states.
124
- Args:
125
- prompt (`str` or `List[str]`, *optional*):
126
- prompt to be encoded
127
- device: (`torch.device`):
128
- torch device
129
- num_images_per_prompt (`int`):
130
- number of images that should be generated per prompt
131
- prompt_embeds (`torch.FloatTensor`, *optional*):
132
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
133
- provided, text embeddings will be generated from `prompt` input argument.
134
- """
135
-
136
- if prompt is not None and isinstance(prompt, str):
137
- pass
138
- elif prompt is not None and isinstance(prompt, list):
139
- len(prompt)
140
- else:
141
- prompt_embeds.shape[0]
142
-
143
- if prompt_embeds is None:
144
- text_inputs = self.tokenizer(
145
- prompt,
146
- padding="max_length",
147
- max_length=self.tokenizer.model_max_length,
148
- truncation=True,
149
- return_tensors="pt",
150
- )
151
- text_input_ids = text_inputs.input_ids
152
- untruncated_ids = self.tokenizer(
153
- prompt, padding="longest", return_tensors="pt"
154
- ).input_ids
155
-
156
- if untruncated_ids.shape[-1] >= text_input_ids.shape[
157
- -1
158
- ] and not torch.equal(text_input_ids, untruncated_ids):
159
- removed_text = self.tokenizer.batch_decode(
160
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
161
- )
162
- logger.warning(
163
- "The following part of your input was truncated because CLIP can only handle sequences up to"
164
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
165
- )
166
-
167
- if (
168
- hasattr(self.text_encoder.config, "use_attention_mask")
169
- and self.text_encoder.config.use_attention_mask
170
- ):
171
- attention_mask = text_inputs.attention_mask.to(device)
172
- else:
173
- attention_mask = None
174
-
175
- prompt_embeds = self.text_encoder(
176
- text_input_ids.to(device),
177
- attention_mask=attention_mask,
178
- )
179
- prompt_embeds = prompt_embeds[0]
180
-
181
- if self.text_encoder is not None:
182
- prompt_embeds_dtype = self.text_encoder.dtype
183
- elif self.unet is not None:
184
- prompt_embeds_dtype = self.unet.dtype
185
- else:
186
- prompt_embeds_dtype = prompt_embeds.dtype
187
-
188
- prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
189
-
190
- bs_embed, seq_len, _ = prompt_embeds.shape
191
- # duplicate text embeddings for each generation per prompt, using mps friendly method
192
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
193
- prompt_embeds = prompt_embeds.view(
194
- bs_embed * num_images_per_prompt, seq_len, -1
195
- )
196
-
197
- # Don't need to get uncond prompt embedding because of LCM Guided Distillation
198
- return prompt_embeds
199
-
200
- def run_safety_checker(self, image, device, dtype):
201
- if self.safety_checker is None:
202
- has_nsfw_concept = None
203
- else:
204
- if torch.is_tensor(image):
205
- feature_extractor_input = self.image_processor.postprocess(
206
- image, output_type="pil"
207
- )
208
- else:
209
- feature_extractor_input = self.image_processor.numpy_to_pil(image)
210
- safety_checker_input = self.feature_extractor(
211
- feature_extractor_input, return_tensors="pt"
212
- ).to(device)
213
- image, has_nsfw_concept = self.safety_checker(
214
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
215
- )
216
- return image, has_nsfw_concept
217
-
218
- def prepare_control_image(
219
- self,
220
- image,
221
- width,
222
- height,
223
- batch_size,
224
- num_images_per_prompt,
225
- device,
226
- dtype,
227
- do_classifier_free_guidance=False,
228
- guess_mode=False,
229
- ):
230
- image = self.control_image_processor.preprocess(
231
- image, height=height, width=width
232
- ).to(dtype=dtype)
233
- image_batch_size = image.shape[0]
234
-
235
- if image_batch_size == 1:
236
- repeat_by = batch_size
237
- else:
238
- # image batch size is the same as prompt batch size
239
- repeat_by = num_images_per_prompt
240
-
241
- image = image.repeat_interleave(repeat_by, dim=0)
242
-
243
- image = image.to(device=device, dtype=dtype)
244
-
245
- if do_classifier_free_guidance and not guess_mode:
246
- image = torch.cat([image] * 2)
247
-
248
- return image
249
-
250
- def prepare_latents(
251
- self,
252
- image,
253
- timestep,
254
- batch_size,
255
- num_channels_latents,
256
- height,
257
- width,
258
- dtype,
259
- device,
260
- latents=None,
261
- generator=None,
262
- ):
263
- shape = (
264
- batch_size,
265
- num_channels_latents,
266
- height // self.vae_scale_factor,
267
- width // self.vae_scale_factor,
268
- )
269
-
270
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
271
- raise ValueError(
272
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
273
- )
274
-
275
- image = image.to(device=device, dtype=dtype)
276
-
277
- # batch_size = batch_size * num_images_per_prompt
278
-
279
- if image.shape[1] == 4:
280
- init_latents = image
281
-
282
- else:
283
- if isinstance(generator, list) and len(generator) != batch_size:
284
- raise ValueError(
285
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
286
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
287
- )
288
-
289
- elif isinstance(generator, list):
290
- init_latents = [
291
- retrieve_latents(
292
- self.vae.encode(image[i : i + 1]), generator=generator[i]
293
- )
294
- for i in range(batch_size)
295
- ]
296
- init_latents = torch.cat(init_latents, dim=0)
297
- else:
298
- init_latents = retrieve_latents(
299
- self.vae.encode(image), generator=generator
300
- )
301
-
302
- init_latents = self.vae.config.scaling_factor * init_latents
303
-
304
- if (
305
- batch_size > init_latents.shape[0]
306
- and batch_size % init_latents.shape[0] == 0
307
- ):
308
- # expand init_latents for batch_size
309
- deprecation_message = (
310
- f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
311
- " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
312
- " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
313
- " your script to pass as many initial images as text prompts to suppress this warning."
314
- )
315
- # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
316
- additional_image_per_prompt = batch_size // init_latents.shape[0]
317
- init_latents = torch.cat(
318
- [init_latents] * additional_image_per_prompt, dim=0
319
- )
320
- elif (
321
- batch_size > init_latents.shape[0]
322
- and batch_size % init_latents.shape[0] != 0
323
- ):
324
- raise ValueError(
325
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
326
- )
327
- else:
328
- init_latents = torch.cat([init_latents], dim=0)
329
-
330
- shape = init_latents.shape
331
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
332
-
333
- # get latents
334
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
335
- latents = init_latents
336
-
337
- return latents
338
-
339
- if latents is None:
340
- latents = torch.randn(shape, dtype=dtype).to(device)
341
- else:
342
- latents = latents.to(device)
343
- # scale the initial noise by the standard deviation required by the scheduler
344
- latents = latents * self.scheduler.init_noise_sigma
345
- return latents
346
-
347
- def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
348
- """
349
- see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
350
- Args:
351
- timesteps: torch.Tensor: generate embedding vectors at these timesteps
352
- embedding_dim: int: dimension of the embeddings to generate
353
- dtype: data type of the generated embeddings
354
- Returns:
355
- embedding vectors with shape `(len(timesteps), embedding_dim)`
356
- """
357
- assert len(w.shape) == 1
358
- w = w * 1000.0
359
-
360
- half_dim = embedding_dim // 2
361
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
362
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
363
- emb = w.to(dtype)[:, None] * emb[None, :]
364
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
365
- if embedding_dim % 2 == 1: # zero pad
366
- emb = torch.nn.functional.pad(emb, (0, 1))
367
- assert emb.shape == (w.shape[0], embedding_dim)
368
- return emb
369
-
370
- def get_timesteps(self, num_inference_steps, strength, device):
371
- # get the original timestep using init_timestep
372
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
373
-
374
- t_start = max(num_inference_steps - init_timestep, 0)
375
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
376
-
377
- return timesteps, num_inference_steps - t_start
378
-
379
- @torch.no_grad()
380
- def __call__(
381
- self,
382
- prompt: Union[str, List[str]] = None,
383
- image: PipelineImageInput = None,
384
- control_image: PipelineImageInput = None,
385
- strength: float = 0.8,
386
- height: Optional[int] = 768,
387
- width: Optional[int] = 768,
388
- guidance_scale: float = 7.5,
389
- num_images_per_prompt: Optional[int] = 1,
390
- latents: Optional[torch.FloatTensor] = None,
391
- generator: Optional[torch.Generator] = None,
392
- num_inference_steps: int = 4,
393
- lcm_origin_steps: int = 50,
394
- prompt_embeds: Optional[torch.FloatTensor] = None,
395
- output_type: Optional[str] = "pil",
396
- return_dict: bool = True,
397
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
398
- controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
399
- guess_mode: bool = True,
400
- control_guidance_start: Union[float, List[float]] = 0.0,
401
- control_guidance_end: Union[float, List[float]] = 1.0,
402
- ):
403
- controlnet = (
404
- self.controlnet._orig_mod
405
- if is_compiled_module(self.controlnet)
406
- else self.controlnet
407
- )
408
- # 0. Default height and width to unet
409
- height = height or self.unet.config.sample_size * self.vae_scale_factor
410
- width = width or self.unet.config.sample_size * self.vae_scale_factor
411
- if not isinstance(control_guidance_start, list) and isinstance(
412
- control_guidance_end, list
413
- ):
414
- control_guidance_start = len(control_guidance_end) * [
415
- control_guidance_start
416
- ]
417
- elif not isinstance(control_guidance_end, list) and isinstance(
418
- control_guidance_start, list
419
- ):
420
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
421
- elif not isinstance(control_guidance_start, list) and not isinstance(
422
- control_guidance_end, list
423
- ):
424
- mult = (
425
- len(controlnet.nets)
426
- if isinstance(controlnet, MultiControlNetModel)
427
- else 1
428
- )
429
- control_guidance_start, control_guidance_end = mult * [
430
- control_guidance_start
431
- ], mult * [control_guidance_end]
432
- # 2. Define call parameters
433
- if prompt is not None and isinstance(prompt, str):
434
- batch_size = 1
435
- elif prompt is not None and isinstance(prompt, list):
436
- batch_size = len(prompt)
437
- else:
438
- batch_size = prompt_embeds.shape[0]
439
-
440
- device = self._execution_device
441
- # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
442
- global_pool_conditions = (
443
- controlnet.config.global_pool_conditions
444
- if isinstance(controlnet, ControlNetModel)
445
- else controlnet.nets[0].config.global_pool_conditions
446
- )
447
- guess_mode = guess_mode or global_pool_conditions
448
- # 3. Encode input prompt
449
- prompt_embeds = self._encode_prompt(
450
- prompt,
451
- device,
452
- num_images_per_prompt,
453
- prompt_embeds=prompt_embeds,
454
- )
455
-
456
- # 3.5 encode image
457
- image = self.image_processor.preprocess(image)
458
-
459
- if isinstance(controlnet, ControlNetModel):
460
- control_image = self.prepare_control_image(
461
- image=control_image,
462
- width=width,
463
- height=height,
464
- batch_size=batch_size * num_images_per_prompt,
465
- num_images_per_prompt=num_images_per_prompt,
466
- device=device,
467
- dtype=controlnet.dtype,
468
- guess_mode=guess_mode,
469
- )
470
- elif isinstance(controlnet, MultiControlNetModel):
471
- control_images = []
472
-
473
- for control_image_ in control_image:
474
- control_image_ = self.prepare_control_image(
475
- image=control_image_,
476
- width=width,
477
- height=height,
478
- batch_size=batch_size * num_images_per_prompt,
479
- num_images_per_prompt=num_images_per_prompt,
480
- device=device,
481
- dtype=controlnet.dtype,
482
- do_classifier_free_guidance=do_classifier_free_guidance,
483
- guess_mode=guess_mode,
484
- )
485
-
486
- control_images.append(control_image_)
487
-
488
- control_image = control_images
489
- else:
490
- assert False
491
-
492
- # 4. Prepare timesteps
493
- self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps)
494
- # timesteps = self.scheduler.timesteps
495
- # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
496
- timesteps = self.scheduler.timesteps
497
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
498
-
499
- print("timesteps: ", timesteps)
500
-
501
- # 5. Prepare latent variable
502
- num_channels_latents = self.unet.config.in_channels
503
- latents = self.prepare_latents(
504
- image,
505
- latent_timestep,
506
- batch_size * num_images_per_prompt,
507
- num_channels_latents,
508
- height,
509
- width,
510
- prompt_embeds.dtype,
511
- device,
512
- latents,
513
- )
514
- bs = batch_size * num_images_per_prompt
515
-
516
- # 6. Get Guidance Scale Embedding
517
- w = torch.tensor(guidance_scale).repeat(bs)
518
- w_embedding = self.get_w_embedding(w, embedding_dim=256).to(
519
- device=device, dtype=latents.dtype
520
- )
521
- controlnet_keep = []
522
- for i in range(len(timesteps)):
523
- keeps = [
524
- 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
525
- for s, e in zip(control_guidance_start, control_guidance_end)
526
- ]
527
- controlnet_keep.append(
528
- keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
529
- )
530
- # 7. LCM MultiStep Sampling Loop:
531
- with self.progress_bar(total=num_inference_steps) as progress_bar:
532
- for i, t in enumerate(timesteps):
533
- ts = torch.full((bs,), t, device=device, dtype=torch.long)
534
- latents = latents.to(prompt_embeds.dtype)
535
- if guess_mode:
536
- # Infer ControlNet only for the conditional batch.
537
- control_model_input = latents
538
- control_model_input = self.scheduler.scale_model_input(
539
- control_model_input, ts
540
- )
541
- controlnet_prompt_embeds = prompt_embeds
542
- else:
543
- control_model_input = latents
544
- controlnet_prompt_embeds = prompt_embeds
545
- if isinstance(controlnet_keep[i], list):
546
- cond_scale = [
547
- c * s
548
- for c, s in zip(
549
- controlnet_conditioning_scale, controlnet_keep[i]
550
- )
551
- ]
552
- else:
553
- controlnet_cond_scale = controlnet_conditioning_scale
554
- if isinstance(controlnet_cond_scale, list):
555
- controlnet_cond_scale = controlnet_cond_scale[0]
556
- cond_scale = controlnet_cond_scale * controlnet_keep[i]
557
-
558
- down_block_res_samples, mid_block_res_sample = self.controlnet(
559
- control_model_input,
560
- ts,
561
- encoder_hidden_states=controlnet_prompt_embeds,
562
- controlnet_cond=control_image,
563
- conditioning_scale=cond_scale,
564
- guess_mode=guess_mode,
565
- return_dict=False,
566
- )
567
- # model prediction (v-prediction, eps, x)
568
- model_pred = self.unet(
569
- latents,
570
- ts,
571
- timestep_cond=w_embedding,
572
- encoder_hidden_states=prompt_embeds,
573
- cross_attention_kwargs=cross_attention_kwargs,
574
- down_block_additional_residuals=down_block_res_samples,
575
- mid_block_additional_residual=mid_block_res_sample,
576
- return_dict=False,
577
- )[0]
578
-
579
- # compute the previous noisy sample x_t -> x_t-1
580
- latents, denoised = self.scheduler.step(
581
- model_pred, i, t, latents, return_dict=False
582
- )
583
-
584
- # # call the callback, if provided
585
- # if i == len(timesteps) - 1:
586
- progress_bar.update()
587
-
588
- denoised = denoised.to(prompt_embeds.dtype)
589
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
590
- self.unet.to("cpu")
591
- self.controlnet.to("cpu")
592
- torch.cuda.empty_cache()
593
- if not output_type == "latent":
594
- image = self.vae.decode(
595
- denoised / self.vae.config.scaling_factor, return_dict=False
596
- )[0]
597
- image, has_nsfw_concept = self.run_safety_checker(
598
- image, device, prompt_embeds.dtype
599
- )
600
- else:
601
- image = denoised
602
- has_nsfw_concept = None
603
-
604
- if has_nsfw_concept is None:
605
- do_denormalize = [True] * image.shape[0]
606
- else:
607
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
608
-
609
- image = self.image_processor.postprocess(
610
- image, output_type=output_type, do_denormalize=do_denormalize
611
- )
612
-
613
- if not return_dict:
614
- return (image, has_nsfw_concept)
615
-
616
- return StableDiffusionPipelineOutput(
617
- images=image, nsfw_content_detected=has_nsfw_concept
618
- )
619
-
620
-
621
- @dataclass
622
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
623
- class LCMSchedulerOutput(BaseOutput):
624
- """
625
- Output class for the scheduler's `step` function output.
626
- Args:
627
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
628
- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
629
- denoising loop.
630
- pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
631
- The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
632
- `pred_original_sample` can be used to preview progress or for guidance.
633
- """
634
-
635
- prev_sample: torch.FloatTensor
636
- denoised: Optional[torch.FloatTensor] = None
637
-
638
-
639
- # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
640
- def betas_for_alpha_bar(
641
- num_diffusion_timesteps,
642
- max_beta=0.999,
643
- alpha_transform_type="cosine",
644
- ):
645
- """
646
- Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
647
- (1-beta) over time from t = [0,1].
648
- Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
649
- to that part of the diffusion process.
650
- Args:
651
- num_diffusion_timesteps (`int`): the number of betas to produce.
652
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
653
- prevent singularities.
654
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
655
- Choose from `cosine` or `exp`
656
- Returns:
657
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
658
- """
659
- if alpha_transform_type == "cosine":
660
-
661
- def alpha_bar_fn(t):
662
- return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
663
-
664
- elif alpha_transform_type == "exp":
665
-
666
- def alpha_bar_fn(t):
667
- return math.exp(t * -12.0)
668
-
669
- else:
670
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
671
-
672
- betas = []
673
- for i in range(num_diffusion_timesteps):
674
- t1 = i / num_diffusion_timesteps
675
- t2 = (i + 1) / num_diffusion_timesteps
676
- betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
677
- return torch.tensor(betas, dtype=torch.float32)
678
-
679
-
680
- def rescale_zero_terminal_snr(betas):
681
- """
682
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
683
- Args:
684
- betas (`torch.FloatTensor`):
685
- the betas that the scheduler is being initialized with.
686
- Returns:
687
- `torch.FloatTensor`: rescaled betas with zero terminal SNR
688
- """
689
- # Convert betas to alphas_bar_sqrt
690
- alphas = 1.0 - betas
691
- alphas_cumprod = torch.cumprod(alphas, dim=0)
692
- alphas_bar_sqrt = alphas_cumprod.sqrt()
693
-
694
- # Store old values.
695
- alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
696
- alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
697
-
698
- # Shift so the last timestep is zero.
699
- alphas_bar_sqrt -= alphas_bar_sqrt_T
700
-
701
- # Scale so the first timestep is back to the old value.
702
- alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
703
-
704
- # Convert alphas_bar_sqrt to betas
705
- alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
706
- alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
707
- alphas = torch.cat([alphas_bar[0:1], alphas])
708
- betas = 1 - alphas
709
-
710
- return betas
711
-
712
-
713
- class LCMScheduler_X(SchedulerMixin, ConfigMixin):
714
- """
715
- `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
716
- non-Markovian guidance.
717
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
718
- methods the library implements for all schedulers such as loading and saving.
719
- Args:
720
- num_train_timesteps (`int`, defaults to 1000):
721
- The number of diffusion steps to train the model.
722
- beta_start (`float`, defaults to 0.0001):
723
- The starting `beta` value of inference.
724
- beta_end (`float`, defaults to 0.02):
725
- The final `beta` value.
726
- beta_schedule (`str`, defaults to `"linear"`):
727
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
728
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
729
- trained_betas (`np.ndarray`, *optional*):
730
- Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
731
- clip_sample (`bool`, defaults to `True`):
732
- Clip the predicted sample for numerical stability.
733
- clip_sample_range (`float`, defaults to 1.0):
734
- The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
735
- set_alpha_to_one (`bool`, defaults to `True`):
736
- Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
737
- there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
738
- otherwise it uses the alpha value at step 0.
739
- steps_offset (`int`, defaults to 0):
740
- An offset added to the inference steps. You can use a combination of `offset=1` and
741
- `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
742
- Diffusion.
743
- prediction_type (`str`, defaults to `epsilon`, *optional*):
744
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
745
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
746
- Video](https://imagen.research.google/video/paper.pdf) paper).
747
- thresholding (`bool`, defaults to `False`):
748
- Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
749
- as Stable Diffusion.
750
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
751
- The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
752
- sample_max_value (`float`, defaults to 1.0):
753
- The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
754
- timestep_spacing (`str`, defaults to `"leading"`):
755
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
756
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
757
- rescale_betas_zero_snr (`bool`, defaults to `False`):
758
- Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
759
- dark samples instead of limiting it to samples with medium brightness. Loosely related to
760
- [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
761
- """
762
-
763
- # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
764
- order = 1
765
-
766
- @register_to_config
767
- def __init__(
768
- self,
769
- num_train_timesteps: int = 1000,
770
- beta_start: float = 0.0001,
771
- beta_end: float = 0.02,
772
- beta_schedule: str = "linear",
773
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
774
- clip_sample: bool = True,
775
- set_alpha_to_one: bool = True,
776
- steps_offset: int = 0,
777
- prediction_type: str = "epsilon",
778
- thresholding: bool = False,
779
- dynamic_thresholding_ratio: float = 0.995,
780
- clip_sample_range: float = 1.0,
781
- sample_max_value: float = 1.0,
782
- timestep_spacing: str = "leading",
783
- rescale_betas_zero_snr: bool = False,
784
- ):
785
- if trained_betas is not None:
786
- self.betas = torch.tensor(trained_betas, dtype=torch.float32)
787
- elif beta_schedule == "linear":
788
- self.betas = torch.linspace(
789
- beta_start, beta_end, num_train_timesteps, dtype=torch.float32
790
- )
791
- elif beta_schedule == "scaled_linear":
792
- # this schedule is very specific to the latent diffusion model.
793
- self.betas = (
794
- torch.linspace(
795
- beta_start**0.5,
796
- beta_end**0.5,
797
- num_train_timesteps,
798
- dtype=torch.float32,
799
- )
800
- ** 2
801
- )
802
- elif beta_schedule == "squaredcos_cap_v2":
803
- # Glide cosine schedule
804
- self.betas = betas_for_alpha_bar(num_train_timesteps)
805
- else:
806
- raise NotImplementedError(
807
- f"{beta_schedule} does is not implemented for {self.__class__}"
808
- )
809
-
810
- # Rescale for zero SNR
811
- if rescale_betas_zero_snr:
812
- self.betas = rescale_zero_terminal_snr(self.betas)
813
-
814
- self.alphas = 1.0 - self.betas
815
- self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
816
-
817
- # At every step in ddim, we are looking into the previous alphas_cumprod
818
- # For the final step, there is no previous alphas_cumprod because we are already at 0
819
- # `set_alpha_to_one` decides whether we set this parameter simply to one or
820
- # whether we use the final alpha of the "non-previous" one.
821
- self.final_alpha_cumprod = (
822
- torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
823
- )
824
-
825
- # standard deviation of the initial noise distribution
826
- self.init_noise_sigma = 1.0
827
-
828
- # setable values
829
- self.num_inference_steps = None
830
- self.timesteps = torch.from_numpy(
831
- np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
832
- )
833
-
834
- def scale_model_input(
835
- self, sample: torch.FloatTensor, timestep: Optional[int] = None
836
- ) -> torch.FloatTensor:
837
- """
838
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
839
- current timestep.
840
- Args:
841
- sample (`torch.FloatTensor`):
842
- The input sample.
843
- timestep (`int`, *optional*):
844
- The current timestep in the diffusion chain.
845
- Returns:
846
- `torch.FloatTensor`:
847
- A scaled input sample.
848
- """
849
- return sample
850
-
851
- def _get_variance(self, timestep, prev_timestep):
852
- alpha_prod_t = self.alphas_cumprod[timestep]
853
- alpha_prod_t_prev = (
854
- self.alphas_cumprod[prev_timestep]
855
- if prev_timestep >= 0
856
- else self.final_alpha_cumprod
857
- )
858
- beta_prod_t = 1 - alpha_prod_t
859
- beta_prod_t_prev = 1 - alpha_prod_t_prev
860
-
861
- variance = (beta_prod_t_prev / beta_prod_t) * (
862
- 1 - alpha_prod_t / alpha_prod_t_prev
863
- )
864
-
865
- return variance
866
-
867
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
868
- def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
869
- """
870
- "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
871
- prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
872
- s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
873
- pixels from saturation at each step. We find that dynamic thresholding results in significantly better
874
- photorealism as well as better image-text alignment, especially when using very large guidance weights."
875
- https://arxiv.org/abs/2205.11487
876
- """
877
- dtype = sample.dtype
878
- batch_size, channels, height, width = sample.shape
879
-
880
- if dtype not in (torch.float32, torch.float64):
881
- sample = (
882
- sample.float()
883
- ) # upcast for quantile calculation, and clamp not implemented for cpu half
884
-
885
- # Flatten sample for doing quantile calculation along each image
886
- sample = sample.reshape(batch_size, channels * height * width)
887
-
888
- abs_sample = sample.abs() # "a certain percentile absolute pixel value"
889
-
890
- s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
891
- s = torch.clamp(
892
- s, min=1, max=self.config.sample_max_value
893
- ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
894
-
895
- s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
896
- sample = (
897
- torch.clamp(sample, -s, s) / s
898
- ) # "we threshold xt0 to the range [-s, s] and then divide by s"
899
-
900
- sample = sample.reshape(batch_size, channels, height, width)
901
- sample = sample.to(dtype)
902
-
903
- return sample
904
-
905
- def set_timesteps(
906
- self,
907
- stength,
908
- num_inference_steps: int,
909
- lcm_origin_steps: int,
910
- device: Union[str, torch.device] = None,
911
- ):
912
- """
913
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
914
- Args:
915
- num_inference_steps (`int`):
916
- The number of diffusion steps used when generating samples with a pre-trained model.
917
- """
918
-
919
- if num_inference_steps > self.config.num_train_timesteps:
920
- raise ValueError(
921
- f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
922
- f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
923
- f" maximal {self.config.num_train_timesteps} timesteps."
924
- )
925
-
926
- self.num_inference_steps = num_inference_steps
927
-
928
- # LCM Timesteps Setting: # Linear Spacing
929
- c = self.config.num_train_timesteps // lcm_origin_steps
930
- lcm_origin_timesteps = (
931
- np.asarray(list(range(1, int(lcm_origin_steps * stength) + 1))) * c - 1
932
- ) # LCM Training Steps Schedule
933
- skipping_step = max(len(lcm_origin_timesteps) // num_inference_steps, 1)
934
- timesteps = lcm_origin_timesteps[::-skipping_step][
935
- :num_inference_steps
936
- ] # LCM Inference Steps Schedule
937
-
938
- self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
939
-
940
- def get_scalings_for_boundary_condition_discrete(self, t):
941
- self.sigma_data = 0.5 # Default: 0.5
942
-
943
- # By dividing 0.1: This is almost a delta function at t=0.
944
- c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
945
- c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
946
- return c_skip, c_out
947
-
948
- def step(
949
- self,
950
- model_output: torch.FloatTensor,
951
- timeindex: int,
952
- timestep: int,
953
- sample: torch.FloatTensor,
954
- eta: float = 0.0,
955
- use_clipped_model_output: bool = False,
956
- generator=None,
957
- variance_noise: Optional[torch.FloatTensor] = None,
958
- return_dict: bool = True,
959
- ) -> Union[LCMSchedulerOutput, Tuple]:
960
- """
961
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
962
- process from the learned model outputs (most often the predicted noise).
963
- Args:
964
- model_output (`torch.FloatTensor`):
965
- The direct output from learned diffusion model.
966
- timestep (`float`):
967
- The current discrete timestep in the diffusion chain.
968
- sample (`torch.FloatTensor`):
969
- A current instance of a sample created by the diffusion process.
970
- eta (`float`):
971
- The weight of noise for added noise in diffusion step.
972
- use_clipped_model_output (`bool`, defaults to `False`):
973
- If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
974
- because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
975
- clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
976
- `use_clipped_model_output` has no effect.
977
- generator (`torch.Generator`, *optional*):
978
- A random number generator.
979
- variance_noise (`torch.FloatTensor`):
980
- Alternative to generating noise with `generator` by directly providing the noise for the variance
981
- itself. Useful for methods such as [`CycleDiffusion`].
982
- return_dict (`bool`, *optional*, defaults to `True`):
983
- Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
984
- Returns:
985
- [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
986
- If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
987
- tuple is returned where the first element is the sample tensor.
988
- """
989
- if self.num_inference_steps is None:
990
- raise ValueError(
991
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
992
- )
993
-
994
- # 1. get previous step value
995
- prev_timeindex = timeindex + 1
996
- if prev_timeindex < len(self.timesteps):
997
- prev_timestep = self.timesteps[prev_timeindex]
998
- else:
999
- prev_timestep = timestep
1000
-
1001
- # 2. compute alphas, betas
1002
- alpha_prod_t = self.alphas_cumprod[timestep]
1003
- alpha_prod_t_prev = (
1004
- self.alphas_cumprod[prev_timestep]
1005
- if prev_timestep >= 0
1006
- else self.final_alpha_cumprod
1007
- )
1008
-
1009
- beta_prod_t = 1 - alpha_prod_t
1010
- beta_prod_t_prev = 1 - alpha_prod_t_prev
1011
-
1012
- # 3. Get scalings for boundary conditions
1013
- c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
1014
-
1015
- # 4. Different Parameterization:
1016
- parameterization = self.config.prediction_type
1017
-
1018
- if parameterization == "epsilon": # noise-prediction
1019
- pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
1020
-
1021
- elif parameterization == "sample": # x-prediction
1022
- pred_x0 = model_output
1023
-
1024
- elif parameterization == "v_prediction": # v-prediction
1025
- pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
1026
-
1027
- # 4. Denoise model output using boundary conditions
1028
- denoised = c_out * pred_x0 + c_skip * sample
1029
-
1030
- # 5. Sample z ~ N(0, I), For MultiStep Inference
1031
- # Noise is not used for one-step sampling.
1032
- if len(self.timesteps) > 1:
1033
- noise = torch.randn(model_output.shape).to(model_output.device)
1034
- prev_sample = (
1035
- alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
1036
- )
1037
- else:
1038
- prev_sample = denoised
1039
-
1040
- if not return_dict:
1041
- return (prev_sample, denoised)
1042
-
1043
- return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
1044
-
1045
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
1046
- def add_noise(
1047
- self,
1048
- original_samples: torch.FloatTensor,
1049
- noise: torch.FloatTensor,
1050
- timesteps: torch.IntTensor,
1051
- ) -> torch.FloatTensor:
1052
- # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
1053
- alphas_cumprod = self.alphas_cumprod.to(
1054
- device=original_samples.device, dtype=original_samples.dtype
1055
- )
1056
- timesteps = timesteps.to(original_samples.device)
1057
-
1058
- sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
1059
- sqrt_alpha_prod = sqrt_alpha_prod.flatten()
1060
- while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
1061
- sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
1062
-
1063
- sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
1064
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
1065
- while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
1066
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
1067
-
1068
- noisy_samples = (
1069
- sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
1070
- )
1071
- return noisy_samples
1072
-
1073
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
1074
- def get_velocity(
1075
- self,
1076
- sample: torch.FloatTensor,
1077
- noise: torch.FloatTensor,
1078
- timesteps: torch.IntTensor,
1079
- ) -> torch.FloatTensor:
1080
- # Make sure alphas_cumprod and timestep have same device and dtype as sample
1081
- alphas_cumprod = self.alphas_cumprod.to(
1082
- device=sample.device, dtype=sample.dtype
1083
- )
1084
- timesteps = timesteps.to(sample.device)
1085
-
1086
- sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
1087
- sqrt_alpha_prod = sqrt_alpha_prod.flatten()
1088
- while len(sqrt_alpha_prod.shape) < len(sample.shape):
1089
- sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
1090
-
1091
- sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
1092
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
1093
- while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
1094
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
1095
-
1096
- velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
1097
- return velocity
1098
-
1099
- def __len__(self):
1100
- return self.config.num_train_timesteps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipelines/controlnet.py CHANGED
@@ -1,8 +1,11 @@
1
- from diffusers import DiffusionPipeline, AutoencoderTiny
2
- from latent_consistency_controlnet import LatentConsistencyModelPipeline_controlnet
3
-
 
 
4
  from compel import Compel
5
  import torch
 
6
 
7
  try:
8
  import intel_extension_for_pytorch as ipex # type: ignore
@@ -11,80 +14,202 @@ except:
11
 
12
  import psutil
13
  from config import Args
14
- from pydantic import BaseModel
15
  from PIL import Image
16
- from typing import Callable
17
 
18
  base_model = "SimianLuo/LCM_Dreamshaper_v7"
19
- WIDTH = 512
20
- HEIGHT = 512
 
 
21
 
22
 
23
  class Pipeline:
 
 
 
 
24
  class InputParams(BaseModel):
25
- seed: int = 2159232
26
- prompt: str
27
- guidance_scale: float = 8.0
28
- strength: float = 0.5
29
- steps: int = 4
30
- lcm_steps: int = 50
31
- width: int = WIDTH
32
- height: int = HEIGHT
33
-
34
- @staticmethod
35
- def create_pipeline(
36
- args: Args, device: torch.device, torch_dtype: torch.dtype
37
- ) -> Callable[["Pipeline.InputParams"], Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  if args.safety_checker:
39
- pipe = DiffusionPipeline.from_pretrained(base_model)
 
 
40
  else:
41
- pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
 
 
 
 
42
  if args.use_taesd:
43
- pipe.vae = AutoencoderTiny.from_pretrained(
44
- "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
45
  )
46
-
47
- pipe.set_progress_bar_config(disable=True)
48
- pipe.to(device=device, dtype=torch_dtype)
49
- pipe.unet.to(memory_format=torch.channels_last)
50
 
51
  # check if computer has less than 64GB of RAM using sys or os
52
  if psutil.virtual_memory().total < 64 * 1024**3:
53
- pipe.enable_attention_slicing()
54
 
55
  if args.torch_compile:
56
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
57
- pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
 
 
 
 
58
 
59
- pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
 
 
 
 
60
 
61
- compel_proc = Compel(
62
- tokenizer=pipe.tokenizer,
63
- text_encoder=pipe.text_encoder,
64
  truncate_long_prompts=False,
65
  )
66
 
67
- def predict(params: "Pipeline.InputParams") -> Image.Image:
68
- generator = torch.manual_seed(params.seed)
69
- prompt_embeds = compel_proc(params.prompt)
70
- # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
71
- results = pipe(
72
- prompt_embeds=prompt_embeds,
73
- generator=generator,
74
- num_inference_steps=params.steps,
75
- guidance_scale=params.guidance_scale,
76
- width=params.width,
77
- height=params.height,
78
- original_inference_steps=params.lcm_steps,
79
- output_type="pil",
80
- )
81
- nsfw_content_detected = (
82
- results.nsfw_content_detected[0]
83
- if "nsfw_content_detected" in results
84
- else False
85
- )
86
- if nsfw_content_detected:
87
- return None
88
- return results.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- return predict
 
1
+ from diffusers import (
2
+ StableDiffusionControlNetImg2ImgPipeline,
3
+ AutoencoderTiny,
4
+ ControlNetModel,
5
+ )
6
  from compel import Compel
7
  import torch
8
+ from pipelines.utils.canny_gpu import SobelOperator
9
 
10
  try:
11
  import intel_extension_for_pytorch as ipex # type: ignore
 
14
 
15
  import psutil
16
  from config import Args
17
+ from pydantic import BaseModel, Field
18
  from PIL import Image
 
19
 
20
  base_model = "SimianLuo/LCM_Dreamshaper_v7"
21
+ taesd_model = "madebyollin/taesd"
22
+ controlnet_model = "lllyasviel/control_v11p_sd15_canny"
23
+
24
+ default_prompt = "Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
25
 
26
 
27
  class Pipeline:
28
+ class Info(BaseModel):
29
+ name: str = "txt2img"
30
+ description: str = "Generates an image from a text prompt"
31
+
32
  class InputParams(BaseModel):
33
+ prompt: str = Field(
34
+ default_prompt,
35
+ title="Prompt",
36
+ field="textarea",
37
+ id="prompt",
38
+ )
39
+ seed: int = Field(
40
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
41
+ )
42
+ steps: int = Field(
43
+ 4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
44
+ )
45
+ width: int = Field(
46
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
47
+ )
48
+ height: int = Field(
49
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
50
+ )
51
+ guidance_scale: float = Field(
52
+ 0.2,
53
+ min=0,
54
+ max=2,
55
+ step=0.001,
56
+ title="Guidance Scale",
57
+ field="range",
58
+ hide=True,
59
+ id="guidance_scale",
60
+ )
61
+ strength: float = Field(
62
+ 0.5,
63
+ min=0.25,
64
+ max=1.0,
65
+ step=0.001,
66
+ title="Strength",
67
+ field="range",
68
+ hide=True,
69
+ id="strength",
70
+ )
71
+ controlnet_scale: float = Field(
72
+ 0.8,
73
+ min=0,
74
+ max=1.0,
75
+ step=0.001,
76
+ title="Controlnet Scale",
77
+ field="range",
78
+ hide=True,
79
+ id="controlnet_scale",
80
+ )
81
+ controlnet_start: float = Field(
82
+ 0.0,
83
+ min=0,
84
+ max=1.0,
85
+ step=0.001,
86
+ title="Controlnet Start",
87
+ field="range",
88
+ hide=True,
89
+ id="controlnet_start",
90
+ )
91
+ controlnet_end: float = Field(
92
+ 1.0,
93
+ min=0,
94
+ max=1.0,
95
+ step=0.001,
96
+ title="Controlnet End",
97
+ field="range",
98
+ hide=True,
99
+ id="controlnet_end",
100
+ )
101
+ canny_low_threshold: float = Field(
102
+ 0.31,
103
+ min=0,
104
+ max=1.0,
105
+ step=0.001,
106
+ title="Canny Low Threshold",
107
+ field="range",
108
+ hide=True,
109
+ id="canny_low_threshold",
110
+ )
111
+ canny_high_threshold: float = Field(
112
+ 0.125,
113
+ min=0,
114
+ max=1.0,
115
+ step=0.001,
116
+ title="Canny High Threshold",
117
+ field="range",
118
+ hide=True,
119
+ id="canny_high_threshold",
120
+ )
121
+ debug_canny: bool = Field(
122
+ False,
123
+ title="Debug Canny",
124
+ field="checkbox",
125
+ hide=True,
126
+ id="debug_canny",
127
+ )
128
+ image: bool = True
129
+
130
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
131
+ controlnet_canny = ControlNetModel.from_pretrained(
132
+ controlnet_model, torch_dtype=torch_dtype
133
+ ).to(device)
134
  if args.safety_checker:
135
+ self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
136
+ base_model, controlnet=controlnet_canny
137
+ )
138
  else:
139
+ self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
140
+ base_model,
141
+ safety_checker=None,
142
+ controlnet=controlnet_canny,
143
+ )
144
  if args.use_taesd:
145
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
146
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
147
  )
148
+ self.canny_torch = SobelOperator(device=device)
149
+ self.pipe.set_progress_bar_config(disable=True)
150
+ self.pipe.to(device=device, dtype=torch_dtype)
151
+ self.pipe.unet.to(memory_format=torch.channels_last)
152
 
153
  # check if computer has less than 64GB of RAM using sys or os
154
  if psutil.virtual_memory().total < 64 * 1024**3:
155
+ self.pipe.enable_attention_slicing()
156
 
157
  if args.torch_compile:
158
+ self.pipe.unet = torch.compile(
159
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
160
+ )
161
+ self.pipe.vae = torch.compile(
162
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
163
+ )
164
 
165
+ self.pipe(
166
+ prompt="warmup",
167
+ image=[Image.new("RGB", (768, 768))],
168
+ control_image=[Image.new("RGB", (768, 768))],
169
+ )
170
 
171
+ self.compel_proc = Compel(
172
+ tokenizer=self.pipe.tokenizer,
173
+ text_encoder=self.pipe.text_encoder,
174
  truncate_long_prompts=False,
175
  )
176
 
177
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
178
+ generator = torch.manual_seed(params.seed)
179
+ prompt_embeds = self.compel_proc(params.prompt)
180
+ control_image = self.canny_torch(
181
+ params.image, params.canny_low_threshold, params.canny_high_threshold
182
+ )
183
+
184
+ results = self.pipe(
185
+ image=params.image,
186
+ control_image=control_image,
187
+ prompt_embeds=prompt_embeds,
188
+ generator=generator,
189
+ strength=params.strength,
190
+ num_inference_steps=params.steps,
191
+ guidance_scale=params.guidance_scale,
192
+ width=params.width,
193
+ height=params.height,
194
+ output_type="pil",
195
+ controlnet_conditioning_scale=params.controlnet_scale,
196
+ control_guidance_start=params.controlnet_start,
197
+ control_guidance_end=params.controlnet_end,
198
+ )
199
+
200
+ nsfw_content_detected = (
201
+ results.nsfw_content_detected[0]
202
+ if "nsfw_content_detected" in results
203
+ else False
204
+ )
205
+ if nsfw_content_detected:
206
+ return None
207
+ result_image = results.images[0]
208
+ if params.debug_canny:
209
+ # paste control_image on top of result_image
210
+ w0, h0 = (200, 200)
211
+ control_image = control_image.resize((w0, h0))
212
+ w1, h1 = result_image.size
213
+ result_image.paste(control_image, (w1 - w0, h1 - h0))
214
 
215
+ return result_image
pipelines/txt2img.py CHANGED
@@ -11,7 +11,6 @@ import psutil
11
  from config import Args
12
  from pydantic import BaseModel, Field
13
  from PIL import Image
14
- from typing import Callable
15
 
16
  base_model = "SimianLuo/LCM_Dreamshaper_v7"
17
  taesd_model = "madebyollin/taesd"
@@ -29,22 +28,19 @@ class Pipeline:
29
  default_prompt,
30
  title="Prompt",
31
  field="textarea",
 
32
  )
33
- seed: int = Field(2159232, min=0, title="Seed", field="seed", hide=True)
34
- strength: float = Field(
35
- 0.5,
36
- min=0,
37
- max=1,
38
- step=0.001,
39
- title="Strength",
40
- field="range",
41
- hide=True,
42
  )
43
-
44
- steps: int = Field(4, min=2, max=15, title="Steps", field="range", hide=True)
45
- width: int = Field(512, min=2, max=15, title="Width", disabled=True, hide=True)
46
  height: int = Field(
47
- 512, min=2, max=15, title="Height", disabled=True, hide=True
48
  )
49
  guidance_scale: float = Field(
50
  8.0,
@@ -54,6 +50,10 @@ class Pipeline:
54
  title="Guidance Scale",
55
  field="range",
56
  hide=True,
 
 
 
 
57
  )
58
 
59
  def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
 
11
  from config import Args
12
  from pydantic import BaseModel, Field
13
  from PIL import Image
 
14
 
15
  base_model = "SimianLuo/LCM_Dreamshaper_v7"
16
  taesd_model = "madebyollin/taesd"
 
28
  default_prompt,
29
  title="Prompt",
30
  field="textarea",
31
+ id="prompt",
32
  )
33
+ seed: int = Field(
34
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
35
+ )
36
+ steps: int = Field(
37
+ 4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
38
+ )
39
+ width: int = Field(
40
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
 
41
  )
 
 
 
42
  height: int = Field(
43
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
44
  )
45
  guidance_scale: float = Field(
46
  8.0,
 
50
  title="Guidance Scale",
51
  field="range",
52
  hide=True,
53
+ id="guidance_scale",
54
+ )
55
+ image: bool = Field(
56
+ True, title="Image", field="checkbox", hide=True, id="image"
57
  )
58
 
59
  def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
canny_gpu.py → pipelines/utils/canny_gpu.py RENAMED
File without changes
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- diffusers==0.23.0
2
  transformers==4.34.1
3
  gradio==3.50.2
4
  --extra-index-url https://download.pytorch.org/whl/cu121;
 
1
+ git+https://github.com/huggingface/diffusers@c697f524761abd2314c030221a3ad2f7791eab4e
2
  transformers==4.34.1
3
  gradio==3.50.2
4
  --extra-index-url https://download.pytorch.org/whl/cu121;
user_queue.py CHANGED
@@ -1,18 +1,29 @@
1
  from typing import Dict, Union
2
  from uuid import UUID
3
- from asyncio import Queue
4
  from PIL import Image
5
- from typing import Tuple, Union
6
- from uuid import UUID
7
- from asyncio import Queue
8
  from PIL import Image
9
 
 
10
  UserId = UUID
 
11
 
12
- InputParams = dict
13
 
14
- QueueContent = Dict[str, Union[Image.Image, InputParams]]
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- UserQueueDict = Dict[UserId, Queue[QueueContent]]
17
 
18
- user_queue_map: UserQueueDict = {}
 
 
1
  from typing import Dict, Union
2
  from uuid import UUID
3
+ import asyncio
4
  from PIL import Image
5
+ from typing import Dict, Union
 
 
6
  from PIL import Image
7
 
8
+ InputParams = dict
9
  UserId = UUID
10
+ EventDataContent = Dict[str, InputParams]
11
 
 
12
 
13
+ class UserDataEvent:
14
+ def __init__(self):
15
+ self.data_event = asyncio.Event()
16
+ self.data_content: EventDataContent = {}
17
+
18
+ def update_data(self, new_data: EventDataContent):
19
+ self.data_content = new_data
20
+ self.data_event.set()
21
+
22
+ async def wait_for_data(self) -> EventDataContent:
23
+ await self.data_event.wait()
24
+ self.data_event.clear()
25
+ return self.data_content
26
 
 
27
 
28
+ UserDataEventMap = Dict[UserId, UserDataEvent]
29
+ user_data_events: UserDataEventMap = {}