shunxing1234 committed on
Commit
f65b4f1
1 Parent(s): af5e5f3

Upload ZEN/file_utils.py

Files changed (1)
  1. ZEN/file_utils.py +287 -0
ZEN/file_utils.py ADDED
@@ -0,0 +1,287 @@
+ # This file is derived from the code at
+ # https://github.com/huggingface/transformers/blob/master/transformers/file_utils.py
+ # and the code at
+ # https://github.com/allenai/allennlp/blob/master/allennlp/common/file_utils.py.
+ #
+ # Original copyright notice:
+ #
+ # This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
+ # Copyright by the AllenNLP authors.
+ """Utilities for working with the local dataset cache."""
+
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
+
+ import fnmatch
+ import json
+ import logging
+ import os
+ import shutil
+ import sys
+ import tempfile
+ from functools import wraps
+ from hashlib import sha256
+ from io import open
+
+ import boto3
+ import requests
+ from botocore.exceptions import ClientError
+ from tqdm import tqdm
+
+ try:
+     from torch.hub import _get_torch_home
+
+     torch_cache_home = _get_torch_home()
+ except ImportError:
+     torch_cache_home = os.path.expanduser(
+         os.getenv('TORCH_HOME', os.path.join(
+             os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
+ default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert')
+
+ try:
+     from urllib.parse import urlparse
+ except ImportError:
+     from urlparse import urlparse
+
+ try:
+     from pathlib import Path
+
+     PYTORCH_PRETRAINED_BERT_CACHE = Path(
+         os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
+ except (AttributeError, ImportError):
+     PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                               default_cache_path)
+
+ CONFIG_NAME = "config.json"
+ WEIGHTS_NAME = "pytorch_model.bin"
+
+ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
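+ # A minimal sketch (not part of the original file) of how the cache location
+ # above resolves; the printed path depends on the TORCH_HOME, XDG_CACHE_HOME,
+ # and PYTORCH_PRETRAINED_BERT_CACHE environment variables.
+ def _demo_cache_location():
+     # Typically ~/.cache/torch/pytorch_pretrained_bert unless overridden.
+     print(PYTORCH_PRETRAINED_BERT_CACHE)
+
+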
+ def url_to_filename(url, etag=None):
+     """
+     Convert `url` into a hashed filename in a repeatable way.
+     If `etag` is specified, append its hash to the url's, delimited
+     by a period.
+     """
+     url_bytes = url.encode('utf-8')
+     url_hash = sha256(url_bytes)
+     filename = url_hash.hexdigest()
+
+     if etag:
+         etag_bytes = etag.encode('utf-8')
+         etag_hash = sha256(etag_bytes)
+         filename += '.' + etag_hash.hexdigest()
+
+     return filename
+
+
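+ # Illustrative sketch (not part of the original file): the cache key is the
+ # sha256 hex digest of the URL, with the ETag's digest appended when known.
+ # The URL and ETag values below are made up for the example.
+ def _demo_url_to_filename():
+     plain = url_to_filename("https://example.com/model.bin")
+     keyed = url_to_filename("https://example.com/model.bin", etag='"abc123"')
+     assert keyed.startswith(plain + '.')  # same URL hash, plus the ETag hash
+     print(plain, keyed)
+
+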
+ def filename_to_url(filename, cache_dir=None):
+     """
+     Return the url and etag (which may be ``None``) stored for `filename`.
+     Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
+     """
+     if cache_dir is None:
+         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+         cache_dir = str(cache_dir)
+
+     cache_path = os.path.join(cache_dir, filename)
+     if not os.path.exists(cache_path):
+         raise EnvironmentError("file {} not found".format(cache_path))
+
+     meta_path = cache_path + '.json'
+     if not os.path.exists(meta_path):
+         raise EnvironmentError("file {} not found".format(meta_path))
+
+     with open(meta_path, encoding="utf-8") as meta_file:
+         metadata = json.load(meta_file)
+     url = metadata['url']
+     etag = metadata['etag']
+
+     return url, etag
+
+
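+ # Illustrative round-trip sketch (not part of the original file): every cache
+ # entry is paired with a `<filename>.json` sidecar holding its url and etag,
+ # which is what filename_to_url() reads back. Paths below are temporary, and
+ # Python 3 text I/O is assumed for brevity.
+ def _demo_filename_to_url():
+     cache_dir = tempfile.mkdtemp()
+     name = url_to_filename("https://example.com/model.bin")
+     with open(os.path.join(cache_dir, name), 'w') as f:
+         f.write("dummy payload")
+     with open(os.path.join(cache_dir, name + '.json'), 'w') as f:
+         f.write(json.dumps({'url': "https://example.com/model.bin", 'etag': None}))
+     print(filename_to_url(name, cache_dir=cache_dir))
+
+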
+ def cached_path(url_or_filename, cache_dir=None):
+     """
+     Given something that might be a URL (or might be a local path),
+     determine which. If it's a URL, download the file and cache it, and
+     return the path to the cached file. If it's already a local path,
+     make sure the file exists and then return the path.
+     """
+     if cache_dir is None:
+         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+     if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
+         url_or_filename = str(url_or_filename)
+     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+         cache_dir = str(cache_dir)
+
+     parsed = urlparse(url_or_filename)
+
+     if parsed.scheme in ('http', 'https', 's3'):
+         # URL, so get it from the cache (downloading if necessary)
+         return get_from_cache(url_or_filename, cache_dir)
+     elif os.path.exists(url_or_filename):
+         # File, and it exists.
+         return url_or_filename
+     elif parsed.scheme == '':
+         # File, but it doesn't exist.
+         raise EnvironmentError("file {} not found".format(url_or_filename))
+     else:
+         # Something unknown
+         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
+
+
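+ # Usage sketch (not part of the original file). The URL is a placeholder:
+ # calling this performs a real download on first use, then hits the cache.
+ def _demo_cached_path():
+     local_path = cached_path("https://example.com/some-model.bin")
+     print(local_path)  # e.g. <cache_dir>/<sha256(url)>.<sha256(etag)>
+
+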
+ def split_s3_path(url):
+     """Split a full s3 path into the bucket name and path."""
+     parsed = urlparse(url)
+     if not parsed.netloc or not parsed.path:
+         raise ValueError("bad s3 path {}".format(url))
+     bucket_name = parsed.netloc
+     s3_path = parsed.path
+     # Remove '/' at beginning of path.
+     if s3_path.startswith("/"):
+         s3_path = s3_path[1:]
+     return bucket_name, s3_path
+
+
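+ # Quick sketch (not part of the original file); the bucket and key are made up.
+ def _demo_split_s3_path():
+     bucket, key = split_s3_path("s3://my-bucket/models/zen/pytorch_model.bin")
+     assert (bucket, key) == ("my-bucket", "models/zen/pytorch_model.bin")
+
+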
+ def s3_request(func):
+     """
+     Wrapper function for s3 requests in order to create more helpful error
+     messages.
+     """
+
+     @wraps(func)
+     def wrapper(url, *args, **kwargs):
+         try:
+             return func(url, *args, **kwargs)
+         except ClientError as exc:
+             if int(exc.response["Error"]["Code"]) == 404:
+                 raise EnvironmentError("file {} not found".format(url))
+             else:
+                 raise
+
+     return wrapper
+
+
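+ # Behavior sketch (not part of the original file): a 404 ClientError from
+ # botocore is translated into EnvironmentError by the decorator above.
+ # The decorated function here is a stand-in that always raises.
+ def _demo_s3_request():
+     @s3_request
+     def always_missing(url):
+         raise ClientError({"Error": {"Code": "404"}}, "HeadObject")
+
+     try:
+         always_missing("s3://some-bucket/missing-key")
+     except EnvironmentError as exc:
+         print(exc)  # file s3://some-bucket/missing-key not found
+
+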
+ @s3_request
+ def s3_etag(url):
+     """Check ETag on S3 object."""
+     s3_resource = boto3.resource("s3")
+     bucket_name, s3_path = split_s3_path(url)
+     s3_object = s3_resource.Object(bucket_name, s3_path)
+     return s3_object.e_tag
+
+
+ @s3_request
+ def s3_get(url, temp_file):
+     """Pull a file directly from S3."""
+     s3_resource = boto3.resource("s3")
+     bucket_name, s3_path = split_s3_path(url)
+     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
+
+
+ def http_get(url, temp_file):
+     """Stream `url` into the open binary file `temp_file`, with a progress bar."""
+     req = requests.get(url, stream=True)
+     content_length = req.headers.get('Content-Length')
+     total = int(content_length) if content_length is not None else None
+     progress = tqdm(unit="B", total=total)
+     for chunk in req.iter_content(chunk_size=1024):
+         if chunk:  # filter out keep-alive new chunks
+             progress.update(len(chunk))
+             temp_file.write(chunk)
+     progress.close()
+
+
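+ # Usage sketch (not part of the original file): http_get() writes into an
+ # already-open binary file object. The URL below is a placeholder.
+ def _demo_http_get():
+     with tempfile.NamedTemporaryFile() as temp_file:
+         http_get("https://example.com/some-model.bin", temp_file)
+         temp_file.flush()
+         print(os.path.getsize(temp_file.name))
+
+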
+ def get_from_cache(url, cache_dir=None):
+     """
+     Given a URL, look for the corresponding dataset in the local cache.
+     If it's not there, download it. Then return the path to the cached file.
+     """
+     if cache_dir is None:
+         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+         cache_dir = str(cache_dir)
+
+     if not os.path.exists(cache_dir):
+         os.makedirs(cache_dir)
+
+     # Get eTag to add to filename, if it exists.
+     if url.startswith("s3://"):
+         etag = s3_etag(url)
+     else:
+         try:
+             response = requests.head(url, allow_redirects=True)
+             if response.status_code != 200:
+                 etag = None
+             else:
+                 etag = response.headers.get("ETag")
+         except EnvironmentError:
+             etag = None
+
+     if sys.version_info[0] == 2 and etag is not None:
+         etag = etag.decode('utf-8')
+     filename = url_to_filename(url, etag)
+
+     # get cache path to put the file
+     cache_path = os.path.join(cache_dir, filename)
+
+     # If we don't have a connection (etag is None) and can't identify the file,
+     # try to get the last downloaded one.
+     if not os.path.exists(cache_path) and etag is None:
+         matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
+         matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
+         if matching_files:
+             cache_path = os.path.join(cache_dir, matching_files[-1])
+
+     if not os.path.exists(cache_path):
+         # Download to temporary file, then copy to cache dir once finished.
+         # Otherwise you get corrupt cache entries if the download gets interrupted.
+         with tempfile.NamedTemporaryFile() as temp_file:
+             logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
+
+             # GET file object
+             if url.startswith("s3://"):
+                 s3_get(url, temp_file)
+             else:
+                 http_get(url, temp_file)
+
+             # we are copying the file before closing it, so flush to avoid truncation
+             temp_file.flush()
+             # shutil.copyfileobj() starts at the current position, so go to the start
+             temp_file.seek(0)
+
+             logger.info("copying %s to cache at %s", temp_file.name, cache_path)
+             with open(cache_path, 'wb') as cache_file:
+                 shutil.copyfileobj(temp_file, cache_file)
+
+             logger.info("creating metadata file for %s", cache_path)
+             meta = {'url': url, 'etag': etag}
+             meta_path = cache_path + '.json'
+             with open(meta_path, 'w') as meta_file:
+                 output_string = json.dumps(meta)
+                 if sys.version_info[0] == 2 and isinstance(output_string, str):
+                     output_string = unicode(output_string, 'utf-8')  # The beauty of python 2
+                 meta_file.write(output_string)
+
+             logger.info("removing temp file %s", temp_file.name)
+
+     return cache_path
+
+
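+ # Sketch of the offline fallback above (not part of the original file): when
+ # no ETag can be fetched, any previously cached `<url-hash>.<etag-hash>` file
+ # is reused, while `.json` sidecars are excluded. Names here are made up.
+ def _demo_offline_fallback():
+     filename = url_to_filename("https://example.com/model.bin")
+     listing = [filename + '.deadbeef', filename + '.deadbeef.json']
+     matching = [s for s in fnmatch.filter(listing, filename + '.*')
+                 if not s.endswith('.json')]
+     print(matching)  # only the payload file survives the filter
+
+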
+ def read_set_from_file(filename):
+     """
+     Extract a de-duplicated collection (set) of text from a file.
+     Expected file format is one item per line.
+     """
+     collection = set()
+     with open(filename, 'r', encoding='utf-8') as file_:
+         for line in file_:
+             collection.add(line.rstrip())
+     return collection
+
+
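+ # Usage sketch (not part of the original file), using a throwaway temp file.
+ def _demo_read_set_from_file():
+     path = os.path.join(tempfile.mkdtemp(), 'vocab.txt')
+     with open(path, 'w', encoding='utf-8') as f:
+         f.write(u'apple\nbanana\napple\n')
+     assert read_set_from_file(path) == {u'apple', u'banana'}
+
+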
+ def get_file_extension(path, dot=True, lower=True):
+     """Return the file extension of `path`, optionally without the dot and/or lowercased."""
+     ext = os.path.splitext(path)[1]
+     ext = ext if dot else ext[1:]
+     return ext.lower() if lower else ext
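+
+
+ # Quick sketch (not part of the original file) of the three flag combinations.
+ def _demo_get_file_extension():
+     assert get_file_extension('model.BIN') == '.bin'
+     assert get_file_extension('model.BIN', dot=False) == 'bin'
+     assert get_file_extension('model.BIN', lower=False) == '.BIN'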