prithivMLmods committed on
Commit 848917c · verified · 1 Parent(s): a9a4f2f

Update app.py

Files changed (1)
  1. app.py +215 -380

app.py CHANGED
@@ -1,399 +1,234 @@
-import os
-import random
-import uuid
-import json
-import time
-import asyncio
-from threading import Thread
-
 import gradio as gr
 import spaces
-import torch
 import numpy as np
 from PIL import Image
-import edge_tts
-import cv2
-
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextIteratorStreamer,
-    Qwen2VLForConditionalGeneration,
-    AutoProcessor,
-)
-from transformers.image_utils import load_image
-from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
-
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-# Load text-only model and tokenizer
-model_id = "prithivMLmods/FastThink-0.5B-Tiny"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-model.eval()
-
-TTS_VOICES = [
-    "en-US-JennyNeural",  # @tts1
-    "en-US-GuyNeural",    # @tts2
-]

-MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
-
-async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-    """Convert text to speech using Edge TTS and save as MP3"""
-    communicate = edge_tts.Communicate(text, voice)
-    await communicate.save(output_file)
-    return output_file
-
-def clean_chat_history(chat_history):
-    """
-    Filter out any chat entries whose "content" is not a string.
-    This helps prevent errors when concatenating previous messages.
-    """
-    cleaned = []
-    for msg in chat_history:
-        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-            cleaned.append(msg)
-    return cleaned
-
-# Environment variables and parameters for Stable Diffusion XL
-# Use : SG161222/RealVisXL_V4.0_Lightning or SG161222/RealVisXL_V5.0_Lightning
-MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
-MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
-USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
-ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
-BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
-
-# Load the SDXL pipeline
-sd_pipe = StableDiffusionXLPipeline.from_pretrained(
-    MODEL_ID_SD,
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    use_safetensors=True,
-    add_watermarker=False,
-).to(device)
-sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
-
-# Ensure that the text encoder is in half-precision if using CUDA.
-if torch.cuda.is_available():
-    sd_pipe.text_encoder = sd_pipe.text_encoder.half()
-
-# Optional: compile the model for speedup if enabled
-if USE_TORCH_COMPILE:
-    sd_pipe.compile()
-
-# Optional: offload parts of the model to CPU if needed
-if ENABLE_CPU_OFFLOAD:
-    sd_pipe.enable_model_cpu_offload()

 MAX_SEED = np.iinfo(np.int32).max

-def save_image(img: Image.Image) -> str:
-    """Save a PIL image with a unique filename and return the path."""
-    unique_name = str(uuid.uuid4()) + ".png"
-    img.save(unique_name)
-    return unique_name

-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    return seed
-
-def progress_bar_html(label: str) -> str:
-    """
-    Returns an HTML snippet for a thin progress bar with a label.
-    The progress bar is styled as a dark red animated bar.
-    """
-    return f'''
-    <div style="display: flex; align-items: center;">
-        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-        <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-        </div>
-    </div>
-    <style>
-    @keyframes loading {{
-        0% {{ transform: translateX(-100%); }}
-        100% {{ transform: translateX(100%); }}
-    }}
-    </style>
-    '''
-
-def downsample_video(video_path):
-    """
-    Downsamples the video to 10 evenly spaced frames.
-    Each frame is returned as a PIL image along with its timestamp.
-    """
-    vidcap = cv2.VideoCapture(video_path)
-    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
-    frames = []
-    # Sample 10 evenly spaced frames.
-    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-    for i in frame_indices:
-        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-        success, image = vidcap.read()
-        if success:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
-            pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
-            frames.append((pil_image, timestamp))
-    vidcap.release()
-    return frames
-
-@spaces.GPU(duration=60, enable_queue=True)
-def generate_image_fn(
-    prompt: str,
-    negative_prompt: str = "",
-    use_negative_prompt: bool = False,
-    seed: int = 1,
-    width: int = 1024,
-    height: int = 1024,
-    guidance_scale: float = 3,
-    num_inference_steps: int = 25,
-    randomize_seed: bool = False,
-    use_resolution_binning: bool = True,
-    num_images: int = 1,
-    progress=gr.Progress(track_tqdm=True),
-):
-    """Generate images using the SDXL pipeline."""
-    seed = int(randomize_seed_fn(seed, randomize_seed))
-    generator = torch.Generator(device=device).manual_seed(seed)

     options = {
-        "prompt": [prompt] * num_images,
-        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
-        "width": width,
-        "height": height,
         "guidance_scale": guidance_scale,
         "num_inference_steps": num_inference_steps,
         "generator": generator,
-        "output_type": "pil",
     }
-    if use_resolution_binning:
-        options["use_resolution_binning"] = True
-
-    images = []
-    # Process in batches
-    for i in range(0, num_images, BATCH_SIZE):
-        batch_options = options.copy()
-        batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
-        if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
-            batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-        # Wrap the pipeline call in autocast if using CUDA
-        if device.type == "cuda":
-            with torch.autocast("cuda", dtype=torch.float16):
-                outputs = sd_pipe(**batch_options)
-        else:
-            outputs = sd_pipe(**batch_options)
-        images.extend(outputs.images)
-    image_paths = [save_image(img) for img in images]
-    return image_paths, seed
-
-@spaces.GPU
-def generate(
-    input_dict: dict,
-    chat_history: list[dict],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
-):
-    """
-    Generates chatbot responses with support for multimodal input, TTS, and image generation.
-    Special commands:
-    - "@tts1" or "@tts2": triggers text-to-speech.
-    - "@image": triggers image generation using the SDXL pipeline.
-    - "@qwen2vl-video": triggers video processing using Qwen2VL.
-    """
-    text = input_dict["text"]
-    files = input_dict.get("files", [])
-    lower_text = text.strip().lower()
-
-    # Branch for image generation.
-    if lower_text.startswith("@image"):
-        # Remove the "@image" tag and use the rest as prompt
-        prompt = text[len("@image"):].strip()
-        yield progress_bar_html("Generating Image")
-        image_paths, used_seed = generate_image_fn(
-            prompt=prompt,
-            negative_prompt="",
-            use_negative_prompt=False,
-            seed=1,
-            width=1024,
-            height=1024,
-            guidance_scale=3,
-            num_inference_steps=25,
-            randomize_seed=True,
-            use_resolution_binning=True,
-            num_images=1,
-        )
-        yield gr.Image(image_paths[0])
-        return
-
-    # New branch for video processing with Qwen2VL.
-    if lower_text.startswith("@video-infer"):
-        prompt = text[len("@video-infer"):].strip()
-        if files:
-            # Assume the first file is a video.
-            video_path = files[0]
-            frames = downsample_video(video_path)
-            messages = [
-                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                {"role": "user", "content": [{"type": "text", "text": prompt}]}
-            ]
-            # Append each frame with its timestamp.
-            for frame in frames:
-                image, timestamp = frame
-                image_path = f"video_frame_{uuid.uuid4().hex}.png"
-                image.save(image_path)
-                messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-                messages[1]["content"].append({"type": "image", "url": image_path})
-        else:
-            messages = [
-                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                {"role": "user", "content": [{"type": "text", "text": prompt}]}
-            ]
-        inputs = processor.apply_chat_template(
-            messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
-        ).to("cuda")
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            **inputs,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-        }
-        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        yield progress_bar_html("Processing video with Qwen2VL")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-        return
-
-    # Determine if TTS is requested.
-    tts_prefix = "@tts"
-    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-    voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-
-    if is_tts and voice_index:
-        voice = TTS_VOICES[voice_index - 1]
-        text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        conversation = [{"role": "user", "content": text}]
-    else:
-        voice = None
-        text = text.replace(tts_prefix, "").strip()
-        conversation = clean_chat_history(chat_history)
-        conversation.append({"role": "user", "content": text})
-
-    if files:
-        if len(files) > 1:
-            images = [load_image(image) for image in files]
-        elif len(files) == 1:
-            images = [load_image(files[0])]
-        else:
-            images = []
-        messages = [{
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ]
-        }]
-        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        yield progress_bar_html("Thinking...")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-    else:
-        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-        input_ids = input_ids.to(model.device)
-        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            "input_ids": input_ids,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "top_p": top_p,
-            "top_k": top_k,
-            "temperature": temperature,
-            "num_beams": 1,
-            "repetition_penalty": repetition_penalty,
-        }
-        t = Thread(target=model.generate, kwargs=generation_kwargs)
-        t.start()
-        outputs = []
-        yield progress_bar_html("Processing...")
-        for new_text in streamer:
-            outputs.append(new_text)
-            yield "".join(outputs)
-        final_response = "".join(outputs)
-        yield final_response
-        if is_tts and voice:
-            output_file = asyncio.run(text_to_speech(final_response, voice))
-            yield gr.Audio(output_file, autoplay=True)
-
-demo = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
-    ],
-    examples=[
-        ["Write the Python Program for Array Rotation"],
-        [{"text": "summarize the letter", "files": ["examples/1.png"]}],
-        [{"text": "@video-infer Describe the Ad", "files": ["examples/coca.mp4"]}],
-        [{"text": "@video-infer Summarize the event in video", "files": ["examples/sky.mp4"]}],
-        [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
-        ["@image Chocolate dripping from a donut"],
-        ["@tts1 Who is Nikola Tesla, and why did he die?"],
-        [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
-        ["@tts2 What causes rainbows to form?"],
-    ],
-    cache_examples=False,
-    type="messages",
-    description="# **QwQ Edge `@video-infer 'prompt..', @image, @tts1`**",
-    fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@tts1, @tts2-voices, @image for image gen, @video-infer for video, default [text, vision]"),
-    stop_btn="Stop Generation",
-    multimodal=True,
-)

 if __name__ == "__main__":
-    demo.queue(max_size=20).launch(share=True)
 import gradio as gr
 import spaces
 import numpy as np
+import random
+from diffusers import DiffusionPipeline
+import torch
 from PIL import Image

+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_repo_id = "stabilityai/stable-diffusion-3.5-large-turbo"
+
+torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+pipe = pipe.to(device)
+
+pipe.load_lora_weights("prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA", weight_name="SD3.5-Turbo-Realism-2.0-LoRA.safetensors")
+trigger_word = "Turbo Realism"
+pipe.fuse_lora(lora_scale=1.0)

 MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 1024
+
+# Define styles
+style_list = [
+    {
+        "name": "3840 x 2160",
+        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
+    },
+    {
+        "name": "2560 x 1440",
+        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
+    },
+    {
+        "name": "HD+",
+        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
+    },
+    {
+        "name": "Style Zero",
+        "prompt": "{prompt}",
+        "negative_prompt": "",
+    },
+]

+STYLE_NAMES = [style["name"] for style in style_list]
+DEFAULT_STYLE_NAME = STYLE_NAMES[0]
+
+grid_sizes = {
+    "2x1": (2, 1),
+    "1x2": (1, 2),
+    "2x2": (2, 2),
+    "2x3": (2, 3),
+    "3x2": (3, 2),
+    "1x1": (1, 1)
+}
+
+@spaces.GPU(duration=60)
+def infer(
+    prompt,
+    negative_prompt="",
+    seed=42,
+    randomize_seed=False,
+    width=1024,
+    height=1024,
+    guidance_scale=7.5,
+    num_inference_steps=10,
+    style="Style Zero",
+    grid_size="1x1",
+    progress=gr.Progress(track_tqdm=True),
+):
+    selected_style = next(s for s in style_list if s["name"] == style)
+    styled_prompt = selected_style["prompt"].format(prompt=prompt)
+    styled_negative_prompt = selected_style["negative_prompt"]

     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
+
+    generator = torch.Generator().manual_seed(seed)
+
+    grid_size_x, grid_size_y = grid_sizes.get(grid_size, (1, 1))
+    num_images = grid_size_x * grid_size_y

     options = {
+        "prompt": styled_prompt,
+        "negative_prompt": styled_negative_prompt,
         "guidance_scale": guidance_scale,
         "num_inference_steps": num_inference_steps,
+        "width": width,
+        "height": height,
         "generator": generator,
+        "num_images_per_prompt": num_images,
     }
+
+    torch.cuda.empty_cache()  # Clear GPU memory
+    result = pipe(**options)
+
+    grid_img = Image.new('RGB', (width * grid_size_x, height * grid_size_y))
+
+    for i, img in enumerate(result.images[:num_images]):
+        grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))
+
+    return grid_img, seed
+
+examples = [
+    "A tiny astronaut hatching from an egg on the moon, 4k, planet theme",
+    "An anime-style illustration of a delicious, golden-brown wiener schnitzel on a plate, served with fresh lemon slices, parsley --style raw5",
+    "Cold coffee in a cup bokeh --ar 85:128 --v 6.0 --style raw5, 4K, Photo-Realistic",
+    "A cat holding a sign that says hello world --ar 85:128 --v 6.0 --style raw"
+]
+
+css = '''
+.gradio-container{max-width: 585px !important}
+h1{text-align:center}
+footer {
+    visibility: hidden
+}
+'''
+
+with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+    with gr.Column(elem_id="col-container"):
+        gr.Markdown("## GRID 6X🪨")
+
+        with gr.Row():
+            prompt = gr.Text(
+                label="Prompt",
+                show_label=False,
+                max_lines=1,
+                placeholder="Enter your prompt",
+                container=False,
+            )
+
+            run_button = gr.Button("Run", scale=0, variant="primary")
+
+        result = gr.Image(label="Result", show_label=False)
+
+
+        with gr.Row(visible=True):
+            grid_size_selection = gr.Dropdown(
+                choices=["2x1", "1x2", "2x2", "2x3", "3x2", "1x1"],
+                value="1x1",
+                label="Grid Size"
+            )
+
+        with gr.Accordion("Advanced Settings", open=False):
+            negative_prompt = gr.Text(
+                label="Negative prompt",
+                max_lines=1,
+                placeholder="Enter a negative prompt",
+                value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
+                visible=False,
+            )
+
+            seed = gr.Slider(
+                label="Seed",
+                minimum=0,
+                maximum=MAX_SEED,
+                step=1,
+                value=0,
+            )
+
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+            with gr.Row():
+                width = gr.Slider(
+                    label="Width",
+                    minimum=512,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=1024,
+                )
+
+                height = gr.Slider(
+                    label="Height",
+                    minimum=512,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=1024,
+                )
+
+            with gr.Row():
+                guidance_scale = gr.Slider(
+                    label="Guidance scale",
+                    minimum=0.0,
+                    maximum=7.5,
+                    step=0.1,
+                    value=0.0,
+                )
+
+                num_inference_steps = gr.Slider(
+                    label="Number of inference steps",
+                    minimum=1,
+                    maximum=50,
+                    step=1,
+                    value=8,
+                )
+
+            style_selection = gr.Radio(
+                show_label=True,
+                container=True,
+                interactive=True,
+                choices=STYLE_NAMES,
+                value=DEFAULT_STYLE_NAME,
+                label="Quality Style",
+            )
+
+        gr.Examples(examples=examples,
+                    inputs=[prompt],
+                    outputs=[result, seed],
+                    fn=infer,
+                    cache_examples=False)
+
+    gr.on(
+        triggers=[run_button.click, prompt.submit],
+        fn=infer,
+        inputs=[
+            prompt,
+            negative_prompt,
+            seed,
+            randomize_seed,
+            width,
+            height,
+            guidance_scale,
+            num_inference_steps,
+            style_selection,
+            grid_size_selection,
+        ],
+        outputs=[result, seed],
+    )

 if __name__ == "__main__":
+    demo.launch()
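
For reference, a minimal standalone sketch (not part of the commit) of the tile-placement arithmetic the new infer() uses to assemble its output grid. The paste_grid helper and the solid-color placeholder tiles are illustrative assumptions; only the (i % grid_size_x * width, i // grid_size_x * height) indexing mirrors the code above.

# Illustrative sketch only: paste_grid and the placeholder tiles are assumptions,
# not part of app.py; the placement arithmetic matches infer() above.
from PIL import Image

def paste_grid(images, grid_size_x, grid_size_y, width, height):
    # Canvas sized for grid_size_x columns by grid_size_y rows of width x height tiles.
    grid_img = Image.new("RGB", (width * grid_size_x, height * grid_size_y))
    for i, img in enumerate(images[: grid_size_x * grid_size_y]):
        # Column cycles with i; row advances every grid_size_x tiles.
        grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))
    return grid_img

if __name__ == "__main__":
    # Four solid-color tiles stand in for pipe(...).images.
    tiles = [Image.new("RGB", (256, 256), c) for c in ("red", "green", "blue", "gray")]
    paste_grid(tiles, grid_size_x=2, grid_size_y=2, width=256, height=256).save("grid_preview.png")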