jamino30 committed
Commit 89e4ae0 · verified · 1 Parent(s): fe13422

Upload folder using huggingface_hub

Files changed (3):
  1. app.py +4 -2
  2. inference.py +2 -1
  3. utils.py +19 -12
app.py CHANGED
@@ -40,7 +40,7 @@ load_model_without_module(sod_model, local_model_path)
 
 style_files = os.listdir('./style_images')
 style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
-lrs = np.logspace(np.log10(0.0025), np.log10(0.25), 10).tolist()
+lrs = np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist()
 img_size = 512
 
 cached_style_features = {}
@@ -54,7 +54,8 @@ for style_name, style_img_path in style_options.items():
 def run(content_image, style_name, style_strength=10):
     yield [None] * 3
     content_img, original_size = preprocess_img(content_image, img_size)
-    content_img = content_img.to(device)
+    content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
+    content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
 
     print('-'*15)
     print('DATETIME:', datetime.now(timezone.utc) - timedelta(hours=4)) # est
@@ -79,6 +80,7 @@ def run(content_image, style_name, style_strength=10):
         model=model,
         sod_model=sod_model,
         content_image=content_img,
+        content_image_norm=content_img_normalized,
         style_features=style_features,
         lr=lrs[style_strength-1],
         apply_to_background=apply_to_background
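
Two things change in app.py: the style-strength learning-rate ladder shifts down (0.001-0.1 instead of 0.0025-0.25), and run() now prepares a second, ImageNet-normalized copy of the content image for the segmentation model. A minimal sketch of the new strength-to-lr mapping, assuming only numpy:

import numpy as np

# 10 log-spaced learning rates between 0.001 and 0.1, as in the new lrs line.
lrs = np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist()

# run() takes style_strength in 1..10 and indexes with style_strength-1,
# so strength 1 -> 0.001 (subtlest) and strength 10 -> 0.1 (strongest).
print(round(lrs[0], 4), round(lrs[-1], 4))  # 0.001 0.1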
inference.py CHANGED
@@ -38,6 +38,7 @@ def inference(
     model,
     sod_model,
     content_image,
+    content_image_norm,
     style_features,
     apply_to_background,
     lr,
@@ -56,7 +57,7 @@ def inference(
     resized_bg_masks = []
     salient_object_ratio = None
     if apply_to_background:
-        segmentation_output = sod_model(content_image)[0]
+        segmentation_output = sod_model(content_image_norm)[0]
         segmentation_output = torch.sigmoid(segmentation_output)
         segmentation_mask = (segmentation_output > 0.7).float()
         background_mask = (segmentation_mask == 0).float()
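
inference() gains a content_image_norm parameter, and the segmentation step now runs on that normalized tensor instead of the raw [0, 1] content image; a SOD model trained with torchvision-style ImageNet normalization would otherwise see out-of-distribution inputs and produce a weaker background mask. The updated call site in app.py then looks roughly like this (a sketch; the return-value name is illustrative, and the model objects and style features are the app's own):

output = inference(
    model=model,
    sod_model=sod_model,
    content_image=content_img,                   # un-normalized, optimized for style
    content_image_norm=content_img_normalized,   # ImageNet-normalized, segmentation only
    style_features=style_features,
    lr=lrs[style_strength-1],
    apply_to_background=apply_to_background,
)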
utils.py CHANGED
@@ -3,28 +3,35 @@ from PIL import Image
 import torch
 import torchvision.transforms as transforms
 
-def preprocess_img_from_path(path_to_image, img_size):
+def preprocess_img_from_path(path_to_image, img_size, normalize=False):
     img = Image.open(path_to_image)
-    return preprocess_img(img, img_size)
+    return preprocess_img(img, img_size, normalize)
 
-def preprocess_img(img: Image, img_size):
+def preprocess_img(img: Image, img_size, normalize=False):
     original_size = img.size
 
-    transform = transforms.Compose([
-        transforms.Resize((img_size, img_size)),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
+    if normalize:
+        transform = transforms.Compose([
+            transforms.Resize((img_size, img_size)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+    else:
+        transform = transforms.Compose([
+            transforms.Resize((img_size, img_size)),
+            transforms.ToTensor()
+        ])
     img = transform(img).unsqueeze(0)
     return img, original_size
 
-def postprocess_img(img, original_size):
+def postprocess_img(img, original_size, normalize=False):
     img = img.detach().cpu().squeeze(0)
 
     # Denormalize the image
-    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
-    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
-    img = img * std + mean
+    if normalize:
+        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
+        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
+        img = img * std + mean
     img = torch.clamp(img, 0, 1)
 
     img = transforms.ToPILImage()(img)
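
With normalization now opt-in, the two preprocessing paths look like this in use (a hypothetical round trip; 'photo.jpg' is a placeholder path):

from utils import preprocess_img_from_path, postprocess_img

# Un-normalized [0, 1] tensor: what the style-transfer optimization works on.
img, original_size = preprocess_img_from_path('photo.jpg', 512)

# ImageNet-normalized copy: what the SOD model now receives.
img_norm, _ = preprocess_img_from_path('photo.jpg', 512, normalize=True)

# postprocess_img must mirror the preprocess-time flag, otherwise the
# mean/std correction is applied (or skipped) asymmetrically.
out = postprocess_img(img_norm, original_size, normalize=True)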