p1atdev commited on
Commit
e47ff0d
·
1 Parent(s): 2bd2f34

chore: load dart model

Browse files
Files changed (1) hide show
  1. app.py +127 -68
app.py CHANGED
@@ -1,51 +1,115 @@
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- #import spaces #[uncomment to use ZeroGPU]
5
  from diffusers import DiffusionPipeline
 
6
  import torch
7
 
 
 
 
 
 
 
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_repo_id = "stabilityai/sdxl-turbo" #Replace to the model you would like to use
10
 
11
- if torch.cuda.is_available():
12
- torch_dtype = torch.float16
13
- else:
14
- torch_dtype = torch.float32
 
 
 
 
 
15
 
16
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
 
 
 
 
 
 
 
17
  pipe = pipe.to(device)
 
 
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  MAX_IMAGE_SIZE = 1024
21
 
22
- #@spaces.GPU #[uncomment to use ZeroGPU]
23
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  if randomize_seed:
26
  seed = random.randint(0, MAX_SEED)
27
-
28
  generator = torch.Generator().manual_seed(seed)
29
-
 
 
 
30
  image = pipe(
31
- prompt = prompt,
32
- negative_prompt = negative_prompt,
33
- guidance_scale = guidance_scale,
34
- num_inference_steps = num_inference_steps,
35
- width = width,
36
- height = height,
37
- generator = generator
38
- ).images[0]
39
-
40
- return image, seed
41
-
42
- examples = [
43
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
44
- "An astronaut riding a green horse",
45
- "A delicious ceviche cheesecake slice",
46
- ]
47
-
48
- css="""
49
  #col-container {
50
  margin: 0 auto;
51
  max-width: 640px;
@@ -53,35 +117,28 @@ css="""
53
  """
54
 
55
  with gr.Blocks(css=css) as demo:
56
-
57
  with gr.Column(elem_id="col-container"):
58
  gr.Markdown(f"""
59
- # Text-to-Image Gradio Template
60
  """)
61
-
62
  with gr.Row():
63
-
64
- prompt = gr.Text(
65
- label="Prompt",
66
- show_label=False,
67
- max_lines=1,
68
- placeholder="Enter your prompt",
69
- container=False,
70
- )
71
-
72
- run_button = gr.Button("Run", scale=0)
73
-
74
  result = gr.Image(label="Result", show_label=False)
75
 
 
 
 
76
  with gr.Accordion("Advanced Settings", open=False):
77
-
78
  negative_prompt = gr.Text(
79
  label="Negative prompt",
80
  max_lines=1,
81
  placeholder="Enter a negative prompt",
82
  visible=False,
 
83
  )
84
-
85
  seed = gr.Slider(
86
  label="Seed",
87
  minimum=0,
@@ -89,54 +146,56 @@ with gr.Blocks(css=css) as demo:
89
  step=1,
90
  value=0,
91
  )
92
-
93
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
94
-
95
  with gr.Row():
96
-
97
  width = gr.Slider(
98
  label="Width",
99
  minimum=256,
100
  maximum=MAX_IMAGE_SIZE,
101
  step=32,
102
- value=1024, #Replace with defaults that work for your model
103
  )
104
-
105
  height = gr.Slider(
106
  label="Height",
107
  minimum=256,
108
  maximum=MAX_IMAGE_SIZE,
109
  step=32,
110
- value=1024, #Replace with defaults that work for your model
111
  )
112
-
113
  with gr.Row():
114
-
115
  guidance_scale = gr.Slider(
116
  label="Guidance scale",
117
- minimum=0.0,
118
  maximum=10.0,
119
- step=0.1,
120
- value=0.0, #Replace with defaults that work for your model
121
  )
122
-
123
  num_inference_steps = gr.Slider(
124
  label="Number of inference steps",
125
  minimum=1,
126
  maximum=50,
127
  step=1,
128
- value=2, #Replace with defaults that work for your model
129
  )
130
-
131
- gr.Examples(
132
- examples = examples,
133
- inputs = [prompt]
134
- )
135
  gr.on(
136
- triggers=[run_button.click, prompt.submit],
137
- fn = infer,
138
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
139
- outputs = [result, seed]
 
 
 
 
 
 
 
 
140
  )
141
 
142
- demo.queue().launch()
 
1
+ import os
2
  import gradio as gr
3
  import numpy as np
4
  import random
5
+ import spaces
6
  from diffusers import DiffusionPipeline
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
  import torch
9
 
10
+ try:
11
+ from dotenv import load_dotenv
12
+
13
+ load_dotenv()
14
+ except:
15
+ print("failed to import dotenv (this is not a problem on the production)")
16
+
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
18
 
19
+ HF_TOKEN = os.environ.get("HF_TOKEN")
20
+ assert HF_TOKEN is not None
21
+
22
+ IMAGE_MODEL_REPO_ID = os.environ.get(
23
+ "IMAGE_MODEL_REPO_ID", "OnomaAIResearch/Illustrious-xl-early-release-v0"
24
+ )
25
+ DART_V3_REPO_ID = os.environ.get("DART_V3_REPO_ID", "p1atdev/dart-v3-llama-8L-241003")
26
+
27
+ torch_dtype = torch.bfloat16
28
 
29
+ dart = AutoModelForCausalLM.from_pretrained(
30
+ DART_V3_REPO_ID,
31
+ torch_dtype=torch_dtype,
32
+ token=HF_TOKEN,
33
+ )
34
+ tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID)
35
+
36
+ pipe = DiffusionPipeline.from_pretrained(IMAGE_MODEL_REPO_ID, torch_dtype=torch_dtype)
37
  pipe = pipe.to(device)
38
+ pipe = torch.compile(pipe)
39
+
40
 
41
  MAX_SEED = np.iinfo(np.int32).max
42
  MAX_IMAGE_SIZE = 1024
43
 
44
+ TEMPLATE = (
45
+ "<|bos|>"
46
+ #
47
+ "<|rating:general|>"
48
+ "{aspect_ratio}"
49
+ "<|length:medium|>"
50
+ #
51
+ "<copyright>original</copyright>"
52
+ #
53
+ "<character></character>"
54
+ #
55
+ "<general>"
56
+ )
57
+
58
+
59
+ @torch.inference_mode
60
+ def generate_prompt(aspect_ratio: str):
61
+ input_ids = tokenizer.encode_plus(
62
+ TEMPLATE.format(aspect_ratio=aspect_ratio)
63
+ ).input_ids
64
 
65
+ output_ids = dart.generate(
66
+ input_ids,
67
+ max_new_tokens=256,
68
+ temperature=1.0,
69
+ top_p=1.0,
70
+ top_k=100,
71
+ num_beams=1,
72
+ )[0]
73
+
74
+ generated = output_ids[len(input_ids) :]
75
+ decoded = ", ".join(tokenizer.batch_decode(generated))
76
+
77
+ return decoded
78
+
79
+
80
+ @spaces.GPU # [uncomment to use ZeroGPU]
81
+ def infer(
82
+ negative_prompt: str,
83
+ seed,
84
+ randomize_seed,
85
+ width,
86
+ height,
87
+ guidance_scale,
88
+ num_inference_steps,
89
+ progress=gr.Progress(track_tqdm=True),
90
+ ):
91
  if randomize_seed:
92
  seed = random.randint(0, MAX_SEED)
93
+
94
  generator = torch.Generator().manual_seed(seed)
95
+
96
+ prompt = generate_prompt("<|aspect_ratio:square|>")
97
+ print(prompt)
98
+
99
  image = pipe(
100
+ prompt=prompt,
101
+ negative_prompt=negative_prompt,
102
+ guidance_scale=guidance_scale,
103
+ num_inference_steps=num_inference_steps,
104
+ width=width,
105
+ height=height,
106
+ generator=generator,
107
+ ).images[0]
108
+
109
+ return image, prompt, seed
110
+
111
+
112
+ css = """
 
 
 
 
 
113
  #col-container {
114
  margin: 0 auto;
115
  max-width: 640px;
 
117
  """
118
 
119
  with gr.Blocks(css=css) as demo:
 
120
  with gr.Column(elem_id="col-container"):
121
  gr.Markdown(f"""
122
+ # Random IllustriousXL
123
  """)
124
+
125
  with gr.Row():
126
+ run_button = gr.Button("Generate random", scale=0)
127
+
 
 
 
 
 
 
 
 
 
128
  result = gr.Image(label="Result", show_label=False)
129
 
130
+ with gr.Accordion("Generation details", open=False):
131
+ prompt_txt = gr.Textbox("Generated prompt", interactive=False)
132
+
133
  with gr.Accordion("Advanced Settings", open=False):
 
134
  negative_prompt = gr.Text(
135
  label="Negative prompt",
136
  max_lines=1,
137
  placeholder="Enter a negative prompt",
138
  visible=False,
139
+ value=" worst quality, comic, multiple views, bad quality, low quality, lowres, displeasing, very displeasing, bad anatomy, bad hands, scan artifacts, monochrome, greyscale, signature, twitter username, jpeg artifacts, 2koma, 4koma, guro, extra digits, fewer digits",
140
  )
141
+
142
  seed = gr.Slider(
143
  label="Seed",
144
  minimum=0,
 
146
  step=1,
147
  value=0,
148
  )
149
+
150
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
151
+
152
  with gr.Row():
 
153
  width = gr.Slider(
154
  label="Width",
155
  minimum=256,
156
  maximum=MAX_IMAGE_SIZE,
157
  step=32,
158
+ value=1024, # Replace with defaults that work for your model
159
  )
160
+
161
  height = gr.Slider(
162
  label="Height",
163
  minimum=256,
164
  maximum=MAX_IMAGE_SIZE,
165
  step=32,
166
+ value=1024, # Replace with defaults that work for your model
167
  )
168
+
169
  with gr.Row():
 
170
  guidance_scale = gr.Slider(
171
  label="Guidance scale",
172
+ minimum=1.0,
173
  maximum=10.0,
174
+ step=0.5,
175
+ value=6.5,
176
  )
177
+
178
  num_inference_steps = gr.Slider(
179
  label="Number of inference steps",
180
  minimum=1,
181
  maximum=50,
182
  step=1,
183
+ value=20,
184
  )
185
+
 
 
 
 
186
  gr.on(
187
+ triggers=[run_button.click],
188
+ fn=infer,
189
+ inputs=[
190
+ negative_prompt,
191
+ seed,
192
+ randomize_seed,
193
+ width,
194
+ height,
195
+ guidance_scale,
196
+ num_inference_steps,
197
+ ],
198
+ outputs=[result, prompt_txt, seed],
199
  )
200
 
201
+ demo.queue().launch()