valhalla commited on
Commit
c1b5cda
1 Parent(s): 54c8dfa

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +61 -1
model.py CHANGED
@@ -18,6 +18,64 @@ from diffusers import (
18
  T2IAdapter,
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  ADAPTER_NAMES = [
22
  "TencentARC/t2i-adapter-canny-sdxl-1.0",
23
  "TencentARC/t2i-adapter-sketch-sdxl-1.0",
@@ -57,7 +115,7 @@ class LineartPreprocessor(Preprocessor):
57
  return self.model.to(device)
58
 
59
  def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
60
- return self.model(image, detect_resolution=384, image_resolution=1024)
61
 
62
 
63
  class MidasPreprocessor(Preprocessor):
@@ -273,6 +331,8 @@ class Model:
273
  if apply_preprocess:
274
  image = self.preprocessor(image)
275
 
 
 
276
  generator = torch.Generator(device=self.device).manual_seed(seed)
277
  out = self.pipe(
278
  prompt=prompt,
 
18
  T2IAdapter,
19
  )
20
 
21
+ SD_XL_BASE_RATIOS = {
22
+ "0.5": (704, 1408),
23
+ "0.52": (704, 1344),
24
+ "0.57": (768, 1344),
25
+ "0.6": (768, 1280),
26
+ "0.68": (832, 1216),
27
+ "0.72": (832, 1152),
28
+ "0.78": (896, 1152),
29
+ "0.82": (896, 1088),
30
+ "0.88": (960, 1088),
31
+ "0.94": (960, 1024),
32
+ "1.0": (1024, 1024),
33
+ "1.07": (1024, 960),
34
+ "1.13": (1088, 960),
35
+ "1.21": (1088, 896),
36
+ "1.29": (1152, 896),
37
+ "1.38": (1152, 832),
38
+ "1.46": (1216, 832),
39
+ "1.67": (1280, 768),
40
+ "1.75": (1344, 768),
41
+ "1.91": (1344, 704),
42
+ "2.0": (1408, 704),
43
+ "2.09": (1472, 704),
44
+ "2.4": (1536, 640),
45
+ "2.5": (1600, 640),
46
+ "2.89": (1664, 576),
47
+ "3.0": (1728, 576),
48
+ }
49
+
50
+ def find_closest_aspect_ratio(target_width, target_height):
51
+ target_ratio = target_width / target_height
52
+ closest_ratio = None
53
+ min_difference = float('inf')
54
+
55
+ for ratio_str, (width, height) in SD_XL_BASE_RATIOS.items():
56
+ ratio = width / height
57
+ difference = abs(target_ratio - ratio)
58
+
59
+ if difference < min_difference:
60
+ min_difference = difference
61
+ closest_ratio = ratio_str
62
+
63
+ return closest_ratio
64
+
65
+
66
+ def resize_to_closest_aspect_ratio(image):
67
+ target_width, target_height = image.size
68
+ closest_ratio = find_closest_aspect_ratio(target_width, target_height)
69
+
70
+ # Get the dimensions from the closest aspect ratio in the dictionary
71
+ new_width, new_height = SD_XL_BASE_RATIOS[closest_ratio]
72
+
73
+ # Resize the image to the new dimensions while preserving the aspect ratio
74
+ resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
75
+
76
+ return resized_image
77
+
78
+
79
  ADAPTER_NAMES = [
80
  "TencentARC/t2i-adapter-canny-sdxl-1.0",
81
  "TencentARC/t2i-adapter-sketch-sdxl-1.0",
 
115
  return self.model.to(device)
116
 
117
  def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
118
+ return self.model(image, detect_resolution=512, image_resolution=1024)
119
 
120
 
121
  class MidasPreprocessor(Preprocessor):
 
331
  if apply_preprocess:
332
  image = self.preprocessor(image)
333
 
334
+ image = resize_to_closest_aspect_ratio(image)
335
+
336
  generator = torch.Generator(device=self.device).manual_seed(seed)
337
  out = self.pipe(
338
  prompt=prompt,