Spaces:
Running
on
A10G
Running
on
A10G
FrozenBurning
commited on
Commit
•
8eda766
1
Parent(s):
d099347
Update app.py
Browse files
app.py
CHANGED
@@ -11,81 +11,6 @@ os.system("git clone https://github.com/FrozenBurning/SceneDreamer.git")
|
|
11 |
os.system("cp -r SceneDreamer/* ./")
|
12 |
os.system("bash install.sh")
|
13 |
|
14 |
-
pretrained_model = dict(file_url='https://drive.google.com/uc?id=1IFu1vNrgF1EaRqPizyEgN_5Vt7Fyg0Mj',
|
15 |
-
alt_url='', file_size=330571863,
|
16 |
-
file_path='./scenedreamer_released.pt',)
|
17 |
-
|
18 |
-
|
19 |
-
def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
|
20 |
-
file_path = file_spec['file_path']
|
21 |
-
if use_alt_url:
|
22 |
-
file_url = file_spec['alt_url']
|
23 |
-
else:
|
24 |
-
file_url = file_spec['file_url']
|
25 |
-
|
26 |
-
file_dir = os.path.dirname(file_path)
|
27 |
-
tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
|
28 |
-
if file_dir:
|
29 |
-
os.makedirs(file_dir, exist_ok=True)
|
30 |
-
|
31 |
-
progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True)
|
32 |
-
for attempts_left in reversed(range(num_attempts)):
|
33 |
-
data_size = 0
|
34 |
-
progress_bar.reset()
|
35 |
-
try:
|
36 |
-
# Download.
|
37 |
-
data_md5 = hashlib.md5()
|
38 |
-
with session.get(file_url, stream=True) as res:
|
39 |
-
res.raise_for_status()
|
40 |
-
with open(tmp_path, 'wb') as f:
|
41 |
-
for chunk in res.iter_content(chunk_size=chunk_size<<10):
|
42 |
-
progress_bar.update(len(chunk))
|
43 |
-
f.write(chunk)
|
44 |
-
data_size += len(chunk)
|
45 |
-
data_md5.update(chunk)
|
46 |
-
|
47 |
-
# Validate.
|
48 |
-
if 'file_size' in file_spec and data_size != file_spec['file_size']:
|
49 |
-
raise IOError('Incorrect file size', file_path)
|
50 |
-
if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
|
51 |
-
raise IOError('Incorrect file MD5', file_path)
|
52 |
-
break
|
53 |
-
|
54 |
-
except Exception as e:
|
55 |
-
# print(e)
|
56 |
-
# Last attempt => raise error.
|
57 |
-
if not attempts_left:
|
58 |
-
raise
|
59 |
-
|
60 |
-
# Handle Google Drive virus checker nag.
|
61 |
-
if data_size > 0 and data_size < 8192:
|
62 |
-
with open(tmp_path, 'rb') as f:
|
63 |
-
data = f.read()
|
64 |
-
links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'confirm=t' in link]
|
65 |
-
if len(links) == 1:
|
66 |
-
file_url = requests.compat.urljoin(file_url, links[0])
|
67 |
-
continue
|
68 |
-
|
69 |
-
progress_bar.close()
|
70 |
-
|
71 |
-
# Rename temp file to the correct name.
|
72 |
-
os.replace(tmp_path, file_path) # atomic
|
73 |
-
|
74 |
-
# Attempt to clean up any leftover temps.
|
75 |
-
for filename in glob.glob(file_path + '.tmp.*'):
|
76 |
-
try:
|
77 |
-
os.remove(filename)
|
78 |
-
except:
|
79 |
-
pass
|
80 |
-
|
81 |
-
print('Downloading SceneDreamer pretrained model...')
|
82 |
-
with requests.Session() as session:
|
83 |
-
try:
|
84 |
-
download_file(session, pretrained_model)
|
85 |
-
except:
|
86 |
-
print('Google Drive download failed.\n')
|
87 |
-
|
88 |
-
|
89 |
|
90 |
import os
|
91 |
import torch
|
|
|
11 |
os.system("cp -r SceneDreamer/* ./")
|
12 |
os.system("bash install.sh")
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
import os
|
16 |
import torch
|