Prgckwb committed on
Commit
a8b4b26
•
1 Parent(s): 7a474e6
Files changed (1) hide show
  1. app.py +59 -24
app.py CHANGED
@@ -5,7 +5,7 @@ import random
5
  import spaces # [uncomment to use ZeroGPU]
6
  from diffusers import DiffusionPipeline
7
  import torch
8
-
9
 
10
  model_ids = [
11
  "Prgckwb/trpfrog-sd3.5-large",
@@ -22,12 +22,13 @@ else:
22
  pipelines = {
23
  model_id: DiffusionPipeline.from_pretrained(
24
  model_id, torch_dtype=torch_dtype
25
- )
26
  for model_id in model_ids
27
  }
28
 
29
 
30
  @spaces.GPU()
 
31
  def inference(
32
  model_id: str,
33
  prompt: str,
@@ -35,31 +36,65 @@ def inference(
35
  height: int,
36
  progress=gr.Progress(track_tqdm=True),
37
  ):
38
- pipe = pipelines[model_id].to(device)
 
39
 
40
- image = pipe(
41
- prompt=prompt,
42
- width=width,
43
- height=height,
44
- ).images[0]
 
 
 
45
 
46
  return image
47
 
48
-
49
- if __name__ == "__main__":
50
  theme = gr.themes.Ocean()
51
 
52
- demo = gr.Interface(
53
- fn=inference,
54
- inputs=[
55
- gr.Dropdown(label="Model", choices=model_ids, value=model_ids[0]),
56
- gr.Textbox(label="Prompt", placeholder="an icon of trpfrog"),
57
- gr.Slider(label="Width", minimum=64, maximum=1024, step=64, value=1024),
58
- gr.Slider(label="Height", minimum=64, maximum=1024, step=64, value=1024),
59
- ],
60
- outputs=[
61
- gr.Image(label="Output"),
62
- ],
63
- theme=theme,
64
- )
65
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import numpy as np
import spaces  # [uncomment to use ZeroGPU]
import torch
from diffusers import DiffusionPipeline
from PIL import Image
 
10
  model_ids = [
11
  "Prgckwb/trpfrog-sd3.5-large",
 
22
  pipelines = {
23
  model_id: DiffusionPipeline.from_pretrained(
24
  model_id, torch_dtype=torch_dtype
25
+ ) if device == 'cuda' else None
26
  for model_id in model_ids
27
  }
28
 
29
 
30
@spaces.GPU()
@torch.inference_mode()
def inference(
    model_id: str,
    prompt: str,
    width: int,
    height: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Run text-to-image generation with the selected pipeline.

    Args:
        model_id: Key into the module-level ``pipelines`` dict.
        prompt: Text prompt for the diffusion model.
        width: Output image width in pixels.
        height: Output image height in pixels.
        progress: Gradio progress tracker (mirrors the pipeline's tqdm bars).

    Returns:
        A ``PIL.Image.Image``: the generated image on CUDA, or a solid
        black placeholder of the requested size when no GPU is available.
    """
    if device == 'cuda':
        pipe = pipelines[model_id].to(device)

        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
        ).images[0]
    else:
        # Generate a pure black placeholder image.
        # (The previous code called np.random.randn without importing
        # numpy — a NameError — and would have produced noise, not black.)
        image = Image.new('RGB', (width, height), color=(0, 0, 0))

    return image
52
 
53
def create_interface():
    """Assemble and return the Gradio Blocks UI for the demo."""
    with gr.Blocks(theme=gr.themes.Ocean()) as demo:
        with gr.Column():
            gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>TrpFrog Diffusion Demo</h1>")

            with gr.Row():
                with gr.Column():
                    model_dropdown = gr.Dropdown(label="Model", choices=model_ids, value=model_ids[0])
                    prompt_box = gr.Textbox(label="Prompt", placeholder="an icon of trpfrog", value="an icon of trpfrog")

                    with gr.Row():
                        width_slider = gr.Slider(label="Width", minimum=64, maximum=2056, step=128, value=1024)
                        height_slider = gr.Slider(label="Height", minimum=64, maximum=2056, step=128, value=1024)

                    with gr.Row():
                        clear_button = gr.ClearButton(components=[prompt_box])
                        generate_button = gr.Button('Generate', variant='primary')

                with gr.Column():
                    result_image = gr.Image(label="Output")

        generation_inputs = [model_dropdown, prompt_box, width_slider, height_slider]
        generation_outputs = [result_image]

        gr.Examples(
            examples=[
                ['Prgckwb/trpfrog-sd3.5-large', 'an icon of trpfrog eating ramen', 1024, 1024],
                ['Prgckwb/trpfrog-sd3.5-large', 'an icon of trpfrog with a gun', 1024, 1024],
            ],
            inputs=generation_inputs,
            outputs=generation_outputs,
            fn=inference,
            cache_mode='eager',
            cache_examples=True,
        )

        # Generation is triggered both by the button and by pressing
        # Enter in the prompt textbox.
        generate_button.click(inference, inputs=generation_inputs, outputs=generation_outputs)
        prompt_box.submit(inference, inputs=generation_inputs, outputs=generation_outputs)

    return demo
94
+
95
+ if __name__ == "__main__":
96
+ try:
97
+ demo = create_interface()
98
+ demo.queue().launch()
99
+ except Exception as e:
100
+ raise gr.Error(e)