Spaces:
Sleeping
Sleeping
Create modules/safe.py
Browse files- modules/safe.py +188 -0
modules/safe.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# this code is adapted from the script contributed by anon from /h/
|
2 |
+
# modified, from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/safe.py
|
3 |
+
|
4 |
+
import io
|
5 |
+
import pickle
|
6 |
+
import collections
|
7 |
+
import sys
|
8 |
+
import traceback
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import numpy
|
12 |
+
import _codecs
|
13 |
+
import zipfile
|
14 |
+
import re
|
15 |
+
|
16 |
+
|
17 |
+
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage; bind
# whichever name this torch version provides so the rest of the module can
# use one spelling.
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
19 |
+
|
20 |
+
|
21 |
+
def encode(*args):
    """Thin pass-through to ``_codecs.encode``, exposed so the restricted
    unpickler can hand it out for the ``_codecs/encode`` global."""
    return _codecs.encode(*args)
|
24 |
+
|
25 |
+
|
26 |
+
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves an explicit allowlist of globals.

    Anything outside the allowlist (and outside the optional
    ``extra_handler`` callback) raises, so a malicious checkpoint cannot
    smuggle in arbitrary callables.
    """

    # Optional callable(module, name) -> object|None; consulted before the
    # built-in allowlist.  Set by check_pt() per load.
    extra_handler = None

    def persistent_load(self, saved_id):
        # torch serializes storages via persistent ids; for verification we
        # never need the real data, so return an empty storage.
        assert saved_id[0] == 'storage'
        return TypedStorage()

    def find_class(self, module, name):
        # Give the caller-supplied handler first refusal.
        if self.extra_handler is not None:
            resolved = self.extra_handler(module, name)
            if resolved is not None:
                return resolved

        # (module, allowed attribute names, object to resolve them on)
        table = (
            ('collections', ('OrderedDict',), collections),
            ('torch._utils', ('_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy'), torch._utils),
            ('torch', ('FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32'), torch),
            ('torch.nn.modules.container', ('ParameterDict',), torch.nn.modules.container),
            ('numpy.core.multiarray', ('scalar', '_reconstruct'), numpy.core.multiarray),
            ('numpy', ('dtype', 'ndarray'), numpy),
        )
        for allowed_module, allowed_names, source in table:
            if module == allowed_module and name in allowed_names:
                return getattr(source, name)

        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")
|
64 |
+
|
65 |
+
|
66 |
+
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")


def check_zip_filenames(filename, names):
    """Raise if the archive contains any member outside torch's layout.

    Args:
        filename: path of the checkpoint being checked (used in error text).
        names: member paths, e.g. from ``ZipFile.namelist()``.

    Raises:
        Exception: on the first member not matching ``allowed_zip_names_re``.
    """
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        # Fix: interpolate the actual archive path (the previous message had
        # a literal "(unknown)" and never used the filename parameter).
        raise Exception(f"bad file inside {filename}: {name}")
|
76 |
+
|
77 |
+
|
78 |
+
def check_pt(filename, extra_handler):
    """Run the restricted unpickler over *filename* without loading tensors.

    Supports both the zip-based torch format and the legacy format.  Raises
    if the pickle references any global outside the allowlist (or the
    optional *extra_handler*).

    Args:
        filename: path to the checkpoint file.
        extra_handler: callable(module, name) -> object|None, or None.
    """
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            # Fix: report the actual filename (messages previously contained
            # a literal "(unknown)" and were f-strings with no placeholders).
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for _ in range(5):
                unpickler.load()
|
104 |
+
|
105 |
+
|
106 |
+
def load(filename, *args, **kwargs):
    """Drop-in replacement for ``torch.load`` that verifies the pickle first,
    honouring any handler installed by the ``Extra`` context manager."""
    handler = global_extra_handler
    return load_with_extra(filename, extra_handler=handler, *args, **kwargs)
|
108 |
+
|
109 |
+
|
110 |
+
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.

    Returns None (after printing diagnostics to stderr) if verification fails.
    """

    try:
        check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        # Fix: interpolate the real path (message previously held a literal
        # "(unknown)" and the f-string had no placeholder).
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("The file is most likely corrupted.", file=sys.stderr)
        return None

    except Exception:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
        print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
        return None

    # Verification passed: do the real (unrestricted) load.
    return unsafe_torch_load(filename, *args, **kwargs)
|
149 |
+
|
150 |
+
|
151 |
+
class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```
    """

    def __init__(self, handler):
        # handler: callable(module, name) -> object|None, installed as the
        # module-global extra handler for the duration of the with-block.
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        # Nested Extra() blocks would silently overwrite each other's
        # handler, so fail loudly instead of allowing nesting.
        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        # Always clear on exit, even when the body raised.
        global_extra_handler = None
|
184 |
+
|
185 |
+
|
186 |
+
# Keep a handle on the genuine torch.load for callers that have already vetted
# their input, then monkey-patch torch.load so every load in the process goes
# through the restricted unpickler by default.
unsafe_torch_load = torch.load
torch.load = load
# Handler installed by the Extra context manager; None means no extra globals
# are allowed beyond the built-in allowlist.
global_extra_handler = None
|