radames commited on
Commit
f419fcc
·
1 Parent(s): 73b790b

live params!

Browse files
Files changed (2) hide show
  1. app-img2img.py +15 -16
  2. img2img/index.html +11 -13
app-img2img.py CHANGED
@@ -49,7 +49,7 @@ else:
49
  pipe.set_progress_bar_config(disable=True)
50
  pipe.to(torch_device="cuda", torch_dtype=torch.float32)
51
  pipe.unet.to(memory_format=torch.channels_last)
52
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
53
  user_queue_map = {}
54
 
55
  # for torch.compile
@@ -58,7 +58,7 @@ pipe(prompt="warmup", image=[Image.new("RGB", (512, 512))])
58
  def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232):
59
  generator = torch.manual_seed(seed)
60
  # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
61
- num_inference_steps = 4
62
  results = pipe(
63
  prompt=prompt,
64
  # generator=generator,
@@ -66,7 +66,7 @@ def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232)
66
  strength=strength,
67
  num_inference_steps=num_inference_steps,
68
  guidance_scale=guidance_scale,
69
- lcm_origin_steps=30,
70
  output_type="pil",
71
  )
72
  nsfw_content_detected = (
@@ -111,11 +111,8 @@ async def websocket_endpoint(websocket: WebSocket):
111
  await websocket.send_json(
112
  {"status": "success", "message": "Connected", "userId": uid}
113
  )
114
- params = await websocket.receive_json()
115
- params = InputParams(**params)
116
  user_queue_map[uid] = {
117
- "queue": asyncio.Queue(),
118
- "params": params,
119
  }
120
  await websocket.send_json(
121
  {"status": "start", "message": "Start Streaming", "userId": uid}
@@ -148,19 +145,16 @@ async def stream(user_id: uuid.UUID):
148
  try:
149
  user_queue = user_queue_map[uid]
150
  queue = user_queue["queue"]
151
- params = user_queue["params"]
152
- seed = params.seed
153
- prompt = params.prompt
154
- strength = params.strength
155
- guidance_scale = params.guidance_scale
156
-
157
  async def generate():
158
  while True:
159
- input_image = await queue.get()
 
 
160
  if input_image is None:
161
  continue
162
 
163
- image = predict(input_image, prompt, guidance_scale, strength, seed)
164
  if image is None:
165
  continue
166
  frame_data = io.BytesIO()
@@ -190,6 +184,8 @@ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
190
  try:
191
  while True:
192
  data = await websocket.receive_bytes()
 
 
193
  pil_image = Image.open(io.BytesIO(data))
194
 
195
  while not queue.empty():
@@ -197,7 +193,10 @@ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
197
  queue.get_nowait()
198
  except asyncio.QueueEmpty:
199
  continue
200
- await queue.put(pil_image)
 
 
 
201
  if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
202
  await websocket.send_json(
203
  {
 
49
  pipe.set_progress_bar_config(disable=True)
50
  pipe.to(torch_device="cuda", torch_dtype=torch.float32)
51
  pipe.unet.to(memory_format=torch.channels_last)
52
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
53
  user_queue_map = {}
54
 
55
  # for torch.compile
 
58
  def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232):
59
  generator = torch.manual_seed(seed)
60
  # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
61
+ num_inference_steps = 3
62
  results = pipe(
63
  prompt=prompt,
64
  # generator=generator,
 
66
  strength=strength,
67
  num_inference_steps=num_inference_steps,
68
  guidance_scale=guidance_scale,
69
+ lcm_origin_steps=20,
70
  output_type="pil",
71
  )
72
  nsfw_content_detected = (
 
111
  await websocket.send_json(
112
  {"status": "success", "message": "Connected", "userId": uid}
113
  )
 
 
114
  user_queue_map[uid] = {
115
+ "queue": asyncio.Queue()
 
116
  }
117
  await websocket.send_json(
118
  {"status": "start", "message": "Start Streaming", "userId": uid}
 
145
  try:
146
  user_queue = user_queue_map[uid]
147
  queue = user_queue["queue"]
148
+
 
 
 
 
 
149
  async def generate():
150
  while True:
151
+ data = await queue.get()
152
+ input_image = data["image"]
153
+ params = data["params"]
154
  if input_image is None:
155
  continue
156
 
157
+ image = predict(input_image, params.prompt, params.guidance_scale, params.strength, params.seed)
158
  if image is None:
159
  continue
160
  frame_data = io.BytesIO()
 
184
  try:
185
  while True:
186
  data = await websocket.receive_bytes()
187
+ params = await websocket.receive_json()
188
+ params = InputParams(**params)
189
  pil_image = Image.open(io.BytesIO(data))
190
 
191
  while not queue.empty():
 
193
  queue.get_nowait()
194
  except asyncio.QueueEmpty:
195
  continue
196
+ await queue.put({
197
+ "image": pil_image,
198
+ "params": params
199
+ })
200
  if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
201
  await websocket.send_json(
202
  {
img2img/index.html CHANGED
@@ -21,10 +21,10 @@
21
  const queueSizeEl = document.querySelector("#queue_size");
22
  const errorEl = document.querySelector("#error");
23
 
24
- function LCMLive(webcamVideo, liveImage) {
25
  let websocket;
26
 
27
- async function start(params) {
28
  return new Promise((resolve, reject) => {
29
  const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
30
  }:${window.location.host}/ws`;
@@ -46,7 +46,6 @@
46
  const data = JSON.parse(event.data);
47
  switch (data.status) {
48
  case "success":
49
- socket.send(JSON.stringify(params));
50
  break;
51
  case "start":
52
  const userId = data.userId;
@@ -71,6 +70,12 @@
71
  ctx.drawImage(webcamVideo, 0, 0, canvas.width, canvas.height);
72
  const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
73
  websocket.send(blob);
 
 
 
 
 
 
74
  }
75
 
76
  function initVideoStream(userId) {
@@ -124,15 +129,11 @@
124
  }
125
 
126
 
127
- const lcmLive = LCMLive(videoEl, imageEl);
128
  startBtn.addEventListener("click", async () => {
129
  try {
130
- const seed = seedEl.value;
131
- const prompt = promptEl.value;
132
- const guidance_scale = guidanceEl.value;
133
- const strength = strengthEl.value;
134
  startBtn.disabled = true;
135
- const res = await lcmLive.start({ seed, prompt, guidance_scale, strength });
136
  startBtn.disabled = false;
137
  if (res.status === "timeout")
138
  toggleMessage("success")
@@ -176,9 +177,6 @@
176
  target="_blank" class="text-blue-500 hover:underline">Diffusers</a> with a MJPEG
177
  stream server.
178
  </p>
179
- <p class="text-sm">
180
- To change settings or prompt, stop the current stream and start a new one.
181
- </p>
182
  <p class="text-sm">
183
  There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
184
  real-time performance. Maximum queue size is 4. <a
@@ -218,7 +216,7 @@
218
  <input type="number" id="seed" name="seed" value="299792458"
219
  class="font-light border border-gray-700 text-right rounded-md p-2">
220
  <button
221
- onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * 2**64-1))"
222
  class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
223
  Rand
224
  </button>
 
21
  const queueSizeEl = document.querySelector("#queue_size");
22
  const errorEl = document.querySelector("#error");
23
 
24
+ function LCMLive(webcamVideo, liveImage, seedEl, promptEl, guidanceEl, strengthEl) {
25
  let websocket;
26
 
27
+ async function start() {
28
  return new Promise((resolve, reject) => {
29
  const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
30
  }:${window.location.host}/ws`;
 
46
  const data = JSON.parse(event.data);
47
  switch (data.status) {
48
  case "success":
 
49
  break;
50
  case "start":
51
  const userId = data.userId;
 
70
  ctx.drawImage(webcamVideo, 0, 0, canvas.width, canvas.height);
71
  const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
72
  websocket.send(blob);
73
+ websocket.send(JSON.stringify({
74
+ "seed": seedEl.value,
75
+ "prompt": promptEl.value,
76
+ "guidance_scale": guidanceEl.value,
77
+ "strength": strengthEl.value
78
+ }));
79
  }
80
 
81
  function initVideoStream(userId) {
 
129
  }
130
 
131
 
132
+ const lcmLive = LCMLive(videoEl, imageEl, seedEl, promptEl, guidanceEl, strengthEl);
133
  startBtn.addEventListener("click", async () => {
134
  try {
 
 
 
 
135
  startBtn.disabled = true;
136
+ const res = await lcmLive.start();
137
  startBtn.disabled = false;
138
  if (res.status === "timeout")
139
  toggleMessage("success")
 
177
  target="_blank" class="text-blue-500 hover:underline">Diffusers</a> with a MJPEG
178
  stream server.
179
  </p>
 
 
 
180
  <p class="text-sm">
181
  There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
182
  real-time performance. Maximum queue size is 4. <a
 
216
  <input type="number" id="seed" name="seed" value="299792458"
217
  class="font-light border border-gray-700 text-right rounded-md p-2">
218
  <button
219
+ onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
220
  class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
221
  Rand
222
  </button>