Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- .gitignore +3 -0
- app.py +1 -1
- u2net/inference.py +3 -2
.gitignore
CHANGED
@@ -173,3 +173,6 @@ cython_debug/
|
|
173 |
|
174 |
# Tensorboard
|
175 |
/runs
|
|
|
|
|
|
|
|
173 |
|
174 |
# Tensorboard
|
175 |
/runs
|
176 |
+
|
177 |
+
# Saved models
|
178 |
+
saved_models
|
app.py
CHANGED
@@ -38,7 +38,7 @@ load_model_without_module(sod_model, 'u2net/saved_models/u2net-duts-msra.safeten
|
|
38 |
|
39 |
style_files = os.listdir('./style_images')
|
40 |
style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
|
41 |
-
lrs = np.logspace(np.log10(0.
|
42 |
img_size = 512
|
43 |
|
44 |
cached_style_features = {}
|
|
|
38 |
|
39 |
style_files = os.listdir('./style_images')
|
40 |
style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
|
41 |
+
lrs = np.logspace(np.log10(0.0025), np.log10(0.25), 10).tolist()
|
42 |
img_size = 512
|
43 |
|
44 |
cached_style_features = {}
|
u2net/inference.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
from torchvision import transforms
|
|
|
4 |
import numpy as np
|
5 |
from PIL import Image
|
6 |
import matplotlib.pyplot as plt
|
@@ -46,12 +47,12 @@ def overlay_segmentation(original_image, binary_mask, alpha=0.5):
|
|
46 |
|
47 |
if __name__ == '__main__':
|
48 |
# ---
|
49 |
-
model_path = 'results/
|
50 |
image_path = 'images/ladies.jpg'
|
51 |
# ---
|
52 |
model = U2Net().to(device)
|
53 |
model = nn.DataParallel(model)
|
54 |
-
model.load_state_dict(
|
55 |
model.eval()
|
56 |
|
57 |
mask = run_inference(model, image_path, threshold=None)
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
from torchvision import transforms
|
4 |
+
from safetensors.torch import load_file
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
import matplotlib.pyplot as plt
|
|
|
47 |
|
48 |
if __name__ == '__main__':
|
49 |
# ---
|
50 |
+
model_path = 'results/u2net-duts-msra.safetensors'
|
51 |
image_path = 'images/ladies.jpg'
|
52 |
# ---
|
53 |
model = U2Net().to(device)
|
54 |
model = nn.DataParallel(model)
|
55 |
+
model.load_state_dict(load_file(model_path, device=device.type))
|
56 |
model.eval()
|
57 |
|
58 |
mask = run_inference(model, image_path, threshold=None)
|