Commit: "Update app.py" — one file changed: app.py
@@ -38,7 +38,10 @@ from urllib.request import urlretrieve
 from scipy.interpolate import LinearNDInterpolator
 from imageio import imread, imwrite
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
+model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
+model = model.eval()
 
 def write_flo(flow, filename):
     """
@@ -116,10 +119,7 @@ def infer(frameA, frameB):
 
 
 # If you can, run this example on a GPU, it will be a lot faster.
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
-model = model.eval()
+
 
 list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
 print(f"list_of_flows type = {type(list_of_flows)}")