tight-inversion committed on
Commit
10d3d92
·
1 Parent(s): 577e44b

Fix encoding

Browse files
Files changed (1) hide show
  1. app.py +26 -44
app.py CHANGED
@@ -28,7 +28,6 @@ def get_models(name: str, device: torch.device, offload: bool, fp8: bool):
28
  ae = load_ae(name, device=device)
29
  return model, ae, t5, clip
30
 
31
-
32
  class FluxGenerator:
33
  def __init__(self, model_name: str, device: str, offload: bool, aggressive_offload: bool, args):
34
  self.device = torch.device(device)
@@ -44,47 +43,7 @@ class FluxGenerator:
44
  self.pulid_model = PuLIDPipeline(self.model, device='cuda', weight_dtype=torch.bfloat16)
45
  self.pulid_model.load_pretrain(args.pretrained_model)
46
 
47
- # function to encode an image into latents
48
- def encode_image_to_latents(self, img, opts):
49
- """
50
- Opposite of decode: Takes a PIL image and encodes it into latents (x).
51
- """
52
- t0 = time.perf_counter()
53
-
54
- # Resize if necessary, or use opts.height / opts.width if you want a fixed size:
55
- img = img.resize((opts.width, opts.height), resample=Image.LANCZOS)
56
-
57
- # Convert image to torch.Tensor and scale to [-1, 1]
58
- # Image is in [0, 255] → scale to [0,1] → then map to [-1,1].
59
- x = np.array(img).astype(np.float32)
60
- x = torch.from_numpy(x) # shape: (H, W, C)
61
- x = (x / 127.5) - 1.0 # now in [-1, 1]
62
- x = rearrange(x, "h w c -> 1 c h w") # shape: (1, C, H, W)
63
-
64
- # Move encoder to device if you are offloading
65
- if self.offload:
66
- self.ae.encoder.to(self.device)
67
-
68
- x = x.to(self.device)
69
-
70
- # 2) Encode with autocast
71
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
72
- x = self.ae.encode(x)
73
-
74
- x = x.to(torch.bfloat16)
75
-
76
-
77
- # 3) Offload if needed
78
- if self.offload:
79
- self.ae.encoder.cpu()
80
- torch.cuda.empty_cache()
81
-
82
- t1 = time.perf_counter()
83
- print(f"Encoded in {t1 - t0:.2f} seconds.")
84
-
85
- return x
86
-
87
- @spaces.GPU(duration=90)
88
  @torch.inference_mode()
89
  def generate_image(
90
  self,
@@ -153,8 +112,31 @@ class FluxGenerator:
153
  noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
154
  if noise.shape[0] == 1 and bs > 1:
155
  noise = repeat(noise, "1 ... -> bs ...", bs=bs)
156
- # encode
157
- x = self.encode_image_to_latents(id_image, opts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)
160
 
 
28
  ae = load_ae(name, device=device)
29
  return model, ae, t5, clip
30
 
 
31
  class FluxGenerator:
32
  def __init__(self, model_name: str, device: str, offload: bool, aggressive_offload: bool, args):
33
  self.device = torch.device(device)
 
43
  self.pulid_model = PuLIDPipeline(self.model, device='cuda', weight_dtype=torch.bfloat16)
44
  self.pulid_model.load_pretrain(args.pretrained_model)
45
 
46
+ @spaces.GPU(duration=60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  @torch.inference_mode()
48
  def generate_image(
49
  self,
 
112
  noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
113
  if noise.shape[0] == 1 and bs > 1:
114
  noise = repeat(noise, "1 ... -> bs ...", bs=bs)
115
+ # Encode id_image directly here
116
+ encode_t0 = time.perf_counter()
117
+
118
+ # Resize image
119
+ id_image = id_image.resize((opts.width, opts.height), resample=Image.LANCZOS)
120
+
121
+ # Convert image to torch.Tensor and scale to [-1, 1]
122
+ x = np.array(id_image).astype(np.float32)
123
+ x = torch.from_numpy(x) # shape: (H, W, C)
124
+ x = (x / 127.5) - 1.0 # now in [-1, 1]
125
+ x = rearrange(x, "h w c -> 1 c h w") # shape: (1, C, H, W)
126
+ x = x.to(self.device)
127
+ # Encode with autocast
128
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
129
+ x = self.ae.encode(x)
130
+
131
+ x = x.to(torch.bfloat16)
132
+
133
+ # Offload if needed
134
+ if self.offload:
135
+ self.ae.encoder.to("cpu")
136
+ torch.cuda.empty_cache()
137
+
138
+ encode_t1 = time.perf_counter()
139
+ print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.")
140
 
141
  timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)
142