gokaygokay commited on
Commit
c2e96d2
1 Parent(s): ecb4b85

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ import torch
4
+ import random
5
+ from huggingface_hub import snapshot_download
6
+ from diffusers import StableDiffusionXLPipeline, AutoencoderKL
7
+ from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler
8
+ import gradio as gr
9
+ from PIL import Image
10
+ from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
11
+
12
+ import subprocess
13
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
14
+
15
+ # Download the model files
16
+ ckpt_dir = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
17
+
18
+ # Load the models
19
+ vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)
20
+
21
+ pipe = StableDiffusionXLPipeline.from_pretrained(
22
+ ckpt_dir,
23
+ vae=vae,
24
+ torch_dtype=torch.float16,
25
+ use_safetensors=True,
26
+ variant="fp16"
27
+ )
28
+ pipe = pipe.to("cuda")
29
+
30
+ # Define samplers
31
+ samplers = {
32
+ "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
33
+ "DPM++ 2M": DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, algorithm_type="dpmsolver++", use_karras_sigmas=True),
34
+ "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
35
+ }
36
+
37
+ DEFAULT_POSITIVE_PREFIX = "score_9, score_8_up, score_7_up, BREAK,"
38
+ DEFAULT_POSITIVE_SUFFIX = "(masterpiece), best quality, very aesthetic, perfect face"
39
+ DEFAULT_NEGATIVE_PREFIX = "score_1, score_2, score_3, text"
40
+ DEFAULT_NEGATIVE_SUFFIX = "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn"
41
+
42
+ # Initialize Florence model
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+ florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
45
+ florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
46
+
47
+ # Prompt Enhancer
48
+ enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
49
+ enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
50
+
51
+ # Florence caption function
52
+ def florence_caption(image):
53
+ # Convert image to PIL if it's not already
54
+ if not isinstance(image, Image.Image):
55
+ image = Image.fromarray(image)
56
+
57
+ inputs = florence_processor(text="<DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
58
+ generated_ids = florence_model.generate(
59
+ input_ids=inputs["input_ids"],
60
+ pixel_values=inputs["pixel_values"],
61
+ max_new_tokens=1024,
62
+ early_stopping=False,
63
+ do_sample=False,
64
+ num_beams=3,
65
+ )
66
+ generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
67
+ parsed_answer = florence_processor.post_process_generation(
68
+ generated_text,
69
+ task="<DETAILED_CAPTION>",
70
+ image_size=(image.width, image.height)
71
+ )
72
+ return parsed_answer["<DETAILED_CAPTION>"]
73
+
74
+ # Prompt Enhancer function
75
+ def enhance_prompt(input_prompt, model_choice):
76
+ if model_choice == "Medium":
77
+ result = enhancer_medium("Enhance the description: " + input_prompt)
78
+ enhanced_text = result[0]['summary_text']
79
+ else: # Long
80
+ result = enhancer_long("Enhance the description: " + input_prompt)
81
+ enhanced_text = result[0]['summary_text']
82
+
83
+ return enhanced_text
84
+
85
+ @spaces.GPU(duration=120)
86
+ def generate_image(additional_positive_prompt, additional_negative_prompt, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler, clip_skip, use_florence2, use_medium_enhancer, use_long_enhancer, input_image=None, progress=gr.Progress(track_tqdm=True)):
87
+ if use_random_seed:
88
+ seed = random.randint(0, 2**32 - 1)
89
+ else:
90
+ seed = int(seed) # Ensure seed is an integer
91
+
92
+ # Set the scheduler based on the selected sampler
93
+ pipe.scheduler = samplers[sampler]
94
+
95
+ # Set clip skip
96
+ pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)
97
+
98
+ # Start with the default positive prompt prefix
99
+ full_positive_prompt = DEFAULT_POSITIVE_PREFIX
100
+
101
+ # Add Florence-2 caption if enabled and image is provided
102
+ if use_florence2 and input_image is not None:
103
+ florence2_caption = florence_caption(input_image)
104
+ florence2_caption = florence2_caption.lower().replace('.', ',')
105
+ additional_positive_prompt = f"{florence2_caption}, {additional_positive_prompt}" if additional_positive_prompt else florence2_caption
106
+
107
+ # Enhance only the additional positive prompt if enhancers are enabled
108
+ if additional_positive_prompt:
109
+ enhanced_prompt = additional_positive_prompt
110
+ if use_medium_enhancer:
111
+ medium_enhanced = enhance_prompt(enhanced_prompt, "Medium")
112
+ medium_enhanced = medium_enhanced.lower().replace('.', ',')
113
+ enhanced_prompt = f"{enhanced_prompt}, {medium_enhanced}"
114
+ if use_long_enhancer:
115
+ long_enhanced = enhance_prompt(enhanced_prompt, "Long")
116
+ long_enhanced = long_enhanced.lower().replace('.', ',')
117
+ enhanced_prompt = f"{enhanced_prompt}, {long_enhanced}"
118
+ full_positive_prompt += f"{enhanced_prompt}"
119
+
120
+ # Add the default positive suffix
121
+ full_positive_prompt += f", {DEFAULT_POSITIVE_SUFFIX}"
122
+
123
+ # Combine default negative prompt with additional negative prompt
124
+ full_negative_prompt = f"{DEFAULT_NEGATIVE_PREFIX}, {additional_negative_prompt}, {DEFAULT_NEGATIVE_SUFFIX}" if additional_negative_prompt else f"{DEFAULT_NEGATIVE_PREFIX}, {DEFAULT_NEGATIVE_SUFFIX}"
125
+
126
+ try:
127
+ image = pipe(
128
+ prompt=full_positive_prompt,
129
+ negative_prompt=full_negative_prompt,
130
+ height=height,
131
+ width=width,
132
+ num_inference_steps=num_inference_steps,
133
+ guidance_scale=guidance_scale,
134
+ num_images_per_prompt=num_images_per_prompt,
135
+ generator=torch.Generator(pipe.device).manual_seed(seed)
136
+ ).images
137
+ return image, seed, full_positive_prompt
138
+ except Exception as e:
139
+ print(f"Error during image generation: {str(e)}")
140
+ return None, seed, full_positive_prompt
141
+
142
+ # Gradio interface
143
+ with gr.Blocks(theme='bethecloud/storj_theme') as demo:
144
+ gr.HTML("""
145
+ <h1 align="center">Pony Realism v21 SDXL - Text-to-Image Generation</h1>
146
+ <p align="center">
147
+ <a href="https://huggingface.co/John6666/pony-realism-v21main-sdxl/" target="_blank">[HF Model Page]</a>
148
+ <a href="https://civitai.com/models/372465/pony-realism" target="_blank">[civitai Model Page]</a>
149
+ <a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
150
+ <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
151
+ <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance" target="_blank">[Prompt Enhancer Medium]</a>
152
+ </p>
153
+ """)
154
+
155
+ with gr.Row():
156
+ with gr.Column(scale=1):
157
+ positive_prompt = gr.Textbox(label="Positive Prompt", placeholder="Add your positive prompt here")
158
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Add your negative prompt here")
159
+
160
+ with gr.Accordion("Advanced settings", open=False):
161
+ height = gr.Slider(512, 2048, 1024, step=64, label="Height")
162
+ width = gr.Slider(512, 2048, 1024, step=64, label="Width")
163
+ num_inference_steps = gr.Slider(20, 50, 30, step=1, label="Number of Inference Steps")
164
+ guidance_scale = gr.Slider(1, 20, 6, step=0.1, label="Guidance Scale")
165
+ num_images_per_prompt = gr.Slider(1, 4, 1, step=1, label="Number of images per prompt")
166
+ use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
167
+ seed = gr.Number(label="Seed", value=0, precision=0)
168
+ sampler = gr.Dropdown(label="Sampler", choices=list(samplers.keys()), value="DPM++ SDE Karras")
169
+ clip_skip = gr.Slider(1, 4, 2, step=1, label="Clip skip")
170
+
171
+ with gr.Accordion("Captioner and Enhancers", open=False):
172
+ input_image = gr.Image(label="Input Image for Florence-2 Captioner")
173
+ use_florence2 = gr.Checkbox(label="Use Florence-2 Captioner", value=False)
174
+ use_medium_enhancer = gr.Checkbox(label="Use Medium Prompt Enhancer", value=False)
175
+ use_long_enhancer = gr.Checkbox(label="Use Long Prompt Enhancer", value=False)
176
+
177
+ generate_btn = gr.Button("Generate Image")
178
+
179
+ with gr.Column(scale=1):
180
+ output_gallery = gr.Gallery(label="Result", elem_id="gallery", show_label=False)
181
+ seed_used = gr.Number(label="Seed Used")
182
+ full_prompt_used = gr.Textbox(label="Full Positive Prompt Used")
183
+
184
+ generate_btn.click(
185
+ fn=generate_image,
186
+ inputs=[
187
+ positive_prompt, negative_prompt, height, width, num_inference_steps,
188
+ guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler,
189
+ clip_skip, use_florence2, use_medium_enhancer, use_long_enhancer, input_image
190
+ ],
191
+ outputs=[output_gallery, seed_used, full_prompt_used]
192
+ )
193
+
194
+ demo.launch(debug=True)