hysts HF staff commited on
Commit
2483fe6
1 Parent(s): 130100d

Use simple adapter names

Browse files
Files changed (2) hide show
  1. app_base.py +6 -6
  2. model.py +30 -29
app_base.py CHANGED
@@ -69,35 +69,35 @@ def create_demo(model: Model) -> gr.Blocks:
69
  [
70
  "assets/org_canny.jpg",
71
  "Mystical fairy in real, magic, 4k picture, high quality",
72
- "TencentARC/t2i-adapter-canny-sdxl-1.0",
73
  0,
74
  True,
75
  ],
76
  [
77
  "assets/org_sketch.png",
78
  "a robot, mount fuji in the background, 4k photo, highly detailed",
79
- "TencentARC/t2i-adapter-sketch-sdxl-1.0",
80
  0,
81
  True,
82
  ],
83
  [
84
  "assets/org_lin.jpg",
85
  "Ice dragon roar, 4k photo",
86
- "TencentARC/t2i-adapter-lineart-sdxl-1.0",
87
  0,
88
  True,
89
  ],
90
  [
91
  "assets/org_mid.jpg",
92
  "A photo of a room, 4k photo, highly detailed",
93
- "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
94
  0,
95
  True,
96
  ],
97
  [
98
  "assets/org_zoe.jpg",
99
  "A photo of a orchid, 4k photo, highly detailed",
100
- "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
101
  0,
102
  True,
103
  ],
@@ -109,7 +109,7 @@ def create_demo(model: Model) -> gr.Blocks:
109
  with gr.Group():
110
  image = gr.Image(label="Input image", type="pil", height=600)
111
  prompt = gr.Textbox(label="Prompt")
112
- adapter_name = gr.Dropdown(label="Adapter", choices=ADAPTER_NAMES, value=ADAPTER_NAMES[0])
113
  run_button = gr.Button("Run")
114
  with gr.Accordion("Advanced options", open=False):
115
  apply_preprocess = gr.Checkbox(label="Apply preprocess", value=True)
 
69
  [
70
  "assets/org_canny.jpg",
71
  "Mystical fairy in real, magic, 4k picture, high quality",
72
+ "canny",
73
  0,
74
  True,
75
  ],
76
  [
77
  "assets/org_sketch.png",
78
  "a robot, mount fuji in the background, 4k photo, highly detailed",
79
+ "sketch",
80
  0,
81
  True,
82
  ],
83
  [
84
  "assets/org_lin.jpg",
85
  "Ice dragon roar, 4k photo",
86
+ "lineart",
87
  0,
88
  True,
89
  ],
90
  [
91
  "assets/org_mid.jpg",
92
  "A photo of a room, 4k photo, highly detailed",
93
+ "depth-midas",
94
  0,
95
  True,
96
  ],
97
  [
98
  "assets/org_zoe.jpg",
99
  "A photo of a orchid, 4k photo, highly detailed",
100
+ "depth-zoe",
101
  0,
102
  True,
103
  ],
 
109
  with gr.Group():
110
  image = gr.Image(label="Input image", type="pil", height=600)
111
  prompt = gr.Textbox(label="Prompt")
112
+ adapter_name = gr.Dropdown(label="Adapter name", choices=ADAPTER_NAMES, value=ADAPTER_NAMES[0])
113
  run_button = gr.Button("Run")
114
  with gr.Accordion("Advanced options", open=False):
115
  apply_preprocess = gr.Checkbox(label="Apply preprocess", value=True)
model.py CHANGED
@@ -77,14 +77,15 @@ def resize_to_closest_aspect_ratio(image: PIL.Image.Image) -> PIL.Image.Image:
77
  return resized_image
78
 
79
 
80
- ADAPTER_NAMES = [
81
- "TencentARC/t2i-adapter-canny-sdxl-1.0",
82
- "TencentARC/t2i-adapter-sketch-sdxl-1.0",
83
- "TencentARC/t2i-adapter-lineart-sdxl-1.0",
84
- "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
85
- "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
86
- # "TencentARC/t2i-adapter-recolor-sdxl-1.0",
87
- ]
 
88
 
89
 
90
  class Preprocessor(ABC):
@@ -169,12 +170,12 @@ PRELOAD_PREPROCESSORS_IN_CPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_CPU_ME
169
  if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
170
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
171
  preprocessors_gpu: dict[str, Preprocessor] = {
172
- "TencentARC/t2i-adapter-canny-sdxl-1.0": CannyPreprocessor().to(device),
173
- "TencentARC/t2i-adapter-sketch-sdxl-1.0": PidiNetPreprocessor().to(device),
174
- "TencentARC/t2i-adapter-lineart-sdxl-1.0": LineartPreprocessor().to(device),
175
- "TencentARC/t2i-adapter-depth-midas-sdxl-1.0": MidasPreprocessor().to(device),
176
- "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0": ZoePreprocessor().to(device),
177
- "TencentARC/t2i-adapter-recolor-sdxl-1.0": RecolorPreprocessor().to(device),
178
  }
179
 
180
  def get_preprocessor(adapter_name: str) -> Preprocessor:
@@ -182,12 +183,12 @@ if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
182
 
183
  elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
184
  preprocessors_cpu: dict[str, Preprocessor] = {
185
- "TencentARC/t2i-adapter-canny-sdxl-1.0": CannyPreprocessor(),
186
- "TencentARC/t2i-adapter-sketch-sdxl-1.0": PidiNetPreprocessor(),
187
- "TencentARC/t2i-adapter-lineart-sdxl-1.0": LineartPreprocessor(),
188
- "TencentARC/t2i-adapter-depth-midas-sdxl-1.0": MidasPreprocessor(),
189
- "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0": ZoePreprocessor(),
190
- "TencentARC/t2i-adapter-recolor-sdxl-1.0": RecolorPreprocessor(),
191
  }
192
 
193
  def get_preprocessor(adapter_name: str) -> Preprocessor:
@@ -196,17 +197,17 @@ elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
196
  else:
197
 
198
  def get_preprocessor(adapter_name: str) -> Preprocessor:
199
- if adapter_name == "TencentARC/t2i-adapter-canny-sdxl-1.0":
200
  return CannyPreprocessor()
201
- elif adapter_name == "TencentARC/t2i-adapter-sketch-sdxl-1.0":
202
  return PidiNetPreprocessor()
203
- elif adapter_name == "TencentARC/t2i-adapter-lineart-sdxl-1.0":
204
  return LineartPreprocessor()
205
- elif adapter_name == "TencentARC/t2i-adapter-depth-midas-sdxl-1.0":
206
  return MidasPreprocessor()
207
- elif adapter_name == "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0":
208
  return ZoePreprocessor()
209
- elif adapter_name == "TencentARC/t2i-adapter-recolor-sdxl-1.0":
210
  return RecolorPreprocessor()
211
  else:
212
  raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
@@ -222,7 +223,7 @@ else:
222
  def download_all_adapters():
223
  for adapter_name in ADAPTER_NAMES:
224
  T2IAdapter.from_pretrained(
225
- adapter_name,
226
  torch_dtype=torch.float16,
227
  varient="fp16",
228
  )
@@ -248,7 +249,7 @@ class Model:
248
 
249
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
250
  adapter = T2IAdapter.from_pretrained(
251
- adapter_name,
252
  torch_dtype=torch.float16,
253
  varient="fp16",
254
  ).to(self.device)
@@ -292,7 +293,7 @@ class Model:
292
  if adapter_name == self.adapter_name:
293
  return
294
  self.pipe.adapter = T2IAdapter.from_pretrained(
295
- adapter_name,
296
  torch_dtype=torch.float16,
297
  varient="fp16",
298
  ).to(self.device)
 
77
  return resized_image
78
 
79
 
80
+ ADAPTER_REPO_IDS = {
81
+ "canny": "TencentARC/t2i-adapter-canny-sdxl-1.0",
82
+ "sketch": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
83
+ "lineart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
84
+ "depth-midas": "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
85
+ "depth-zoe": "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
86
+ # "recolor": "TencentARC/t2i-adapter-recolor-sdxl-1.0",
87
+ }
88
+ ADAPTER_NAMES = list(ADAPTER_REPO_IDS.keys())
89
 
90
 
91
  class Preprocessor(ABC):
 
170
  if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
171
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
172
  preprocessors_gpu: dict[str, Preprocessor] = {
173
+ "canny": CannyPreprocessor().to(device),
174
+ "sketch": PidiNetPreprocessor().to(device),
175
+ "lineart": LineartPreprocessor().to(device),
176
+ "depth-midas": MidasPreprocessor().to(device),
177
+ "depth-zoe": ZoePreprocessor().to(device),
178
+ "recolor": RecolorPreprocessor().to(device),
179
  }
180
 
181
  def get_preprocessor(adapter_name: str) -> Preprocessor:
 
183
 
184
  elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
185
  preprocessors_cpu: dict[str, Preprocessor] = {
186
+ "canny": CannyPreprocessor(),
187
+ "sketch": PidiNetPreprocessor(),
188
+ "lineart": LineartPreprocessor(),
189
+ "depth-midas": MidasPreprocessor(),
190
+ "depth-zoe": ZoePreprocessor(),
191
+ "recolor": RecolorPreprocessor(),
192
  }
193
 
194
  def get_preprocessor(adapter_name: str) -> Preprocessor:
 
197
  else:
198
 
199
  def get_preprocessor(adapter_name: str) -> Preprocessor:
200
+ if adapter_name == "canny":
201
  return CannyPreprocessor()
202
+ elif adapter_name == "sketch":
203
  return PidiNetPreprocessor()
204
+ elif adapter_name == "lineart":
205
  return LineartPreprocessor()
206
+ elif adapter_name == "depth-midas":
207
  return MidasPreprocessor()
208
+ elif adapter_name == "depth-zoe":
209
  return ZoePreprocessor()
210
+ elif adapter_name == "recolor":
211
  return RecolorPreprocessor()
212
  else:
213
  raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
 
223
  def download_all_adapters():
224
  for adapter_name in ADAPTER_NAMES:
225
  T2IAdapter.from_pretrained(
226
+ ADAPTER_REPO_IDS[adapter_name],
227
  torch_dtype=torch.float16,
228
  varient="fp16",
229
  )
 
249
 
250
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
251
  adapter = T2IAdapter.from_pretrained(
252
+ ADAPTER_REPO_IDS[adapter_name],
253
  torch_dtype=torch.float16,
254
  varient="fp16",
255
  ).to(self.device)
 
293
  if adapter_name == self.adapter_name:
294
  return
295
  self.pipe.adapter = T2IAdapter.from_pretrained(
296
+ ADAPTER_REPO_IDS[adapter_name],
297
  torch_dtype=torch.float16,
298
  varient="fp16",
299
  ).to(self.device)