ducha-aiki commited on
Commit
818a15e
1 Parent(s): 2db4e66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -12,7 +12,9 @@ def geometry_transform(images: list,
12
 
13
  file_names: list = [f.name for f in images]
14
  image_list: list = [K.io.load_image(f, K.io.ImageLoadType(0)).float().unsqueeze(0)/255 for f in file_names]
15
- image_batch: torch.Tensor = torch.cat(image_list, 0)
 
 
16
  center: torch.Tensor = torch.tensor([x.shape[1:] for x in image_batch])/2
17
  translation = torch.tensor(translation).repeat(len(image_list), 2)
18
  scale = torch.tensor(scale).repeat(len(image_list), 2)
@@ -20,7 +22,7 @@ def geometry_transform(images: list,
20
  affine_matrix: torch.Tensor = KG.get_affine_matrix2d(translation, center, scale, angle)
21
  with torch.inference_mode():
22
  transformed: torch.Tensor = KG.transform.warp_affine(image_batch, affine_matrix[:, :2], dsize=image_batch.shape[2:])
23
- concat_images: list = torch.cat(transformed, dim=-1)
24
  final_images: np.ndarray = K.tensor_to_image(concat_images*255).astype(np.uint8)
25
 
26
  return final_images
 
12
 
13
  file_names: list = [f.name for f in images]
14
  image_list: list = [K.io.load_image(f, K.io.ImageLoadType(0)).float().unsqueeze(0)/255 for f in file_names]
15
+ if len(image_list) > 1:
16
+ image_list = [K.geometry.resize(x, x.shape[-2:], antialias=True) for x in image_list]
17
+ image_batch: torch.Tensor = torch.cat(image_list, 0)
18
  center: torch.Tensor = torch.tensor([x.shape[1:] for x in image_batch])/2
19
  translation = torch.tensor(translation).repeat(len(image_list), 2)
20
  scale = torch.tensor(scale).repeat(len(image_list), 2)
 
22
  affine_matrix: torch.Tensor = KG.get_affine_matrix2d(translation, center, scale, angle)
23
  with torch.inference_mode():
24
  transformed: torch.Tensor = KG.transform.warp_affine(image_batch, affine_matrix[:, :2], dsize=image_batch.shape[2:])
25
+ concat_images: list = torch.cat([x for x in transformed], dim=-1)
26
  final_images: np.ndarray = K.tensor_to_image(concat_images*255).astype(np.uint8)
27
 
28
  return final_images