Spaces:
Runtime error
Runtime error
anonymous
commited on
Commit
•
2a8678c
1
Parent(s):
4fcfd85
update
Browse files- src/img_util.py +3 -1
src/img_util.py
CHANGED
@@ -2,6 +2,8 @@ import einops
|
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
|
|
|
|
|
5 |
|
6 |
@torch.no_grad()
|
7 |
def find_flat_region(mask):
|
@@ -18,6 +20,6 @@ def find_flat_region(mask):
|
|
18 |
|
19 |
|
20 |
def numpy2tensor(img):
|
21 |
-
x0 = torch.from_numpy(img.copy()).float().
|
22 |
x0 = torch.stack([x0], dim=0)
|
23 |
return einops.rearrange(x0, 'b h w c -> b c h w').clone()
|
|
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
|
5 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
6 |
+
|
7 |
|
8 |
@torch.no_grad()
|
9 |
def find_flat_region(mask):
|
|
|
20 |
|
21 |
|
22 |
def numpy2tensor(img):
|
23 |
+
x0 = torch.from_numpy(img.copy()).float().to(device) / 255.0 * 2.0 - 1.
|
24 |
x0 = torch.stack([x0], dim=0)
|
25 |
return einops.rearrange(x0, 'b h w c -> b c h w').clone()
|