p1atdev commited on
Commit
a4bee0b
·
1 Parent(s): 78612b5

fix: no pipe torch.compile

Browse files
Files changed (1) hide show
  1. app.py +43 -17
app.py CHANGED
@@ -31,12 +31,15 @@ dart = AutoModelForCausalLM.from_pretrained(
31
  DART_V3_REPO_ID,
32
  torch_dtype=torch_dtype,
33
  token=HF_TOKEN,
 
34
  )
 
 
 
35
  tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID)
36
 
37
  pipe = DiffusionPipeline.from_pretrained(IMAGE_MODEL_REPO_ID, torch_dtype=torch_dtype)
38
  pipe = pipe.to(device)
39
- pipe = torch.compile(pipe)
40
 
41
 
42
  MAX_SEED = np.iinfo(np.int32).max
@@ -60,12 +63,15 @@ TEMPLATE = (
60
  @torch.inference_mode
61
  def generate_prompt(aspect_ratio: str):
62
  input_ids = tokenizer.encode_plus(
63
- TEMPLATE.format(aspect_ratio=aspect_ratio)
 
64
  ).input_ids
 
65
 
66
  output_ids = dart.generate(
67
  input_ids,
68
  max_new_tokens=256,
 
69
  temperature=1.0,
70
  top_p=1.0,
71
  top_k=100,
@@ -73,13 +79,35 @@ def generate_prompt(aspect_ratio: str):
73
  )[0]
74
 
75
  generated = output_ids[len(input_ids) :]
76
- decoded = ", ".join(tokenizer.batch_decode(generated))
 
77
 
78
  return decoded
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- @spaces.GPU # [uncomment to use ZeroGPU]
82
- def infer(
83
  negative_prompt: str,
84
  seed,
85
  randomize_seed,
@@ -87,25 +115,23 @@ def infer(
87
  height,
88
  guidance_scale,
89
  num_inference_steps,
90
- progress=gr.Progress(track_tqdm=True),
91
  ):
92
  if randomize_seed:
93
  seed = random.randint(0, MAX_SEED)
94
-
95
  generator = torch.Generator().manual_seed(seed)
96
 
97
  prompt = generate_prompt("<|aspect_ratio:square|>")
98
  print(prompt)
99
 
100
- image = pipe(
101
- prompt=prompt,
102
- negative_prompt=negative_prompt,
103
- guidance_scale=guidance_scale,
104
- num_inference_steps=num_inference_steps,
105
- width=width,
106
- height=height,
107
- generator=generator,
108
- ).images[0]
109
 
110
  return image, prompt, seed
111
 
@@ -186,7 +212,7 @@ with gr.Blocks(css=css) as demo:
186
 
187
  gr.on(
188
  triggers=[run_button.click],
189
- fn=infer,
190
  inputs=[
191
  negative_prompt,
192
  seed,
 
31
  DART_V3_REPO_ID,
32
  torch_dtype=torch_dtype,
33
  token=HF_TOKEN,
34
+ use_cache=True,
35
  )
36
+ dart = dart.eval()
37
+ dart = dart.requires_grad_(False)
38
+ dart = torch.compile(dart)
39
  tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID)
40
 
41
  pipe = DiffusionPipeline.from_pretrained(IMAGE_MODEL_REPO_ID, torch_dtype=torch_dtype)
42
  pipe = pipe.to(device)
 
43
 
44
 
45
  MAX_SEED = np.iinfo(np.int32).max
 
63
  @torch.inference_mode
64
  def generate_prompt(aspect_ratio: str):
65
  input_ids = tokenizer.encode_plus(
66
+ TEMPLATE.format(aspect_ratio=aspect_ratio),
67
+ return_tensors="pt",
68
  ).input_ids
69
+ print("input_ids", input_ids)
70
 
71
  output_ids = dart.generate(
72
  input_ids,
73
  max_new_tokens=256,
74
+ do_sample=True,
75
  temperature=1.0,
76
  top_p=1.0,
77
  top_k=100,
 
79
  )[0]
80
 
81
  generated = output_ids[len(input_ids) :]
82
+ decoded = ", ".join([token for token in tokenizer.batch_decode(generated, skip_special_tokens=True) if token.strip() != ""])
83
+ print("decoded", decoded)
84
 
85
  return decoded
86
 
87
+ @spaces.GPU
88
+ def generate_image(
89
+ prompt: str,
90
+ negative_prompt: str,
91
+ generator,
92
+ width: int,
93
+ height: int,
94
+ guidance_scale: float,
95
+ num_inference_steps: int,
96
+ progress=gr.Progress(track_tqdm=True),
97
+ ):
98
+ image = pipe(
99
+ prompt=prompt,
100
+ negative_prompt=negative_prompt,
101
+ guidance_scale=guidance_scale,
102
+ num_inference_steps=num_inference_steps,
103
+ width=width,
104
+ height=height,
105
+ generator=generator,
106
+ ).images[0]
107
+
108
+ return image
109
 
110
+ def on_generate(
 
111
  negative_prompt: str,
112
  seed,
113
  randomize_seed,
 
115
  height,
116
  guidance_scale,
117
  num_inference_steps,
 
118
  ):
119
  if randomize_seed:
120
  seed = random.randint(0, MAX_SEED)
 
121
  generator = torch.Generator().manual_seed(seed)
122
 
123
  prompt = generate_prompt("<|aspect_ratio:square|>")
124
  print(prompt)
125
 
126
+ image = generate_image(
127
+ prompt,
128
+ negative_prompt,
129
+ generator,
130
+ width,
131
+ height,
132
+ guidance_scale,
133
+ num_inference_steps,
134
+ )
135
 
136
  return image, prompt, seed
137
 
 
212
 
213
  gr.on(
214
  triggers=[run_button.click],
215
+ fn=on_generate,
216
  inputs=[
217
  negative_prompt,
218
  seed,