fffiloni commited on
Commit
c1fff88
1 Parent(s): 39564a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -194,14 +194,19 @@ def infer(style_description, ref_style_file, caption):
194
  sampled = models_b.stage_a.decode(sampled_b).float()
195
 
196
  sampled = torch.cat([
197
- torch.nn.functional.interpolate(ref_style.cpu(), size=height),
198
  sampled.cpu(),
199
- ],
200
- dim=0)
201
 
202
- # Save the sampled image to a file
203
- sampled_image = T.ToPILImage()(sampled.squeeze(0)) # Convert tensor to PIL image
204
- sampled_image.save(output_file) # Save the image
 
 
 
 
 
 
205
 
206
  return output_file # Return the path to the saved image
207
 
 
194
  sampled = models_b.stage_a.decode(sampled_b).float()
195
 
196
  sampled = torch.cat([
197
+ torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
198
  sampled.cpu(),
199
+ ], dim=0)
 
200
 
201
+ # Remove the batch dimension and keep only the generated image
202
+ sampled = sampled[1] # This selects the generated image, discarding the reference style image
203
+
204
+ # Ensure the tensor is in [C, H, W] format
205
+ if sampled.dim() == 3 and sampled.shape[0] == 3:
206
+ sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
207
+ sampled_image.save(output_file) # Save the image as a PNG
208
+ else:
209
+ raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
210
 
211
  return output_file # Return the path to the saved image
212