jamino30 commited on
Commit
429658f
·
verified ·
1 Parent(s): e9b440b

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +1 -1
  3. u2net/inference.py +3 -2
.gitignore CHANGED
@@ -173,3 +173,6 @@ cython_debug/
173
 
174
  # Tensorboard
175
  /runs
 
 
 
 
173
 
174
  # Tensorboard
175
  /runs
176
+
177
+ # Saved models
178
+ saved_models
app.py CHANGED
@@ -38,7 +38,7 @@ load_model_without_module(sod_model, 'u2net/saved_models/u2net-duts-msra.safeten
38
 
39
  style_files = os.listdir('./style_images')
40
  style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
41
- lrs = np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist()
42
  img_size = 512
43
 
44
  cached_style_features = {}
 
38
 
39
  style_files = os.listdir('./style_images')
40
  style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
41
+ lrs = np.logspace(np.log10(0.0025), np.log10(0.25), 10).tolist()
42
  img_size = 512
43
 
44
  cached_style_features = {}
u2net/inference.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import transforms
 
4
  import numpy as np
5
  from PIL import Image
6
  import matplotlib.pyplot as plt
@@ -46,12 +47,12 @@ def overlay_segmentation(original_image, binary_mask, alpha=0.5):
46
 
47
  if __name__ == '__main__':
48
  # ---
49
- model_path = 'results/inter-u2net-duts.pt'
50
  image_path = 'images/ladies.jpg'
51
  # ---
52
  model = U2Net().to(device)
53
  model = nn.DataParallel(model)
54
- model.load_state_dict(torch.load(model_path, map_location=device, weights_only=False))
55
  model.eval()
56
 
57
  mask = run_inference(model, image_path, threshold=None)
 
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import transforms
4
+ from safetensors.torch import load_file
5
  import numpy as np
6
  from PIL import Image
7
  import matplotlib.pyplot as plt
 
47
 
48
  if __name__ == '__main__':
49
  # ---
50
+ model_path = 'results/u2net-duts-msra.safetensors'
51
  image_path = 'images/ladies.jpg'
52
  # ---
53
  model = U2Net().to(device)
54
  model = nn.DataParallel(model)
55
+ model.load_state_dict(load_file(model_path, device=device.type))
56
  model.eval()
57
 
58
  mask = run_inference(model, image_path, threshold=None)