|
""" |
|
Utilities for working with the local dataset cache. |
|
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp |
|
Copyright by the AllenNLP authors. |
|
""" |
|
|
|
import fnmatch |
|
import json |
|
import logging |
|
import os |
|
import shutil |
|
import sys |
|
import tarfile |
|
import tempfile |
|
from contextlib import contextmanager |
|
from functools import partial, wraps |
|
from hashlib import sha256 |
|
from pathlib import Path |
|
from typing import Dict, Optional, Union |
|
from urllib.parse import urlparse |
|
from zipfile import ZipFile, is_zipfile |
|
|
|
import requests |
|
from filelock import FileLock |
|
from tqdm.auto import tqdm |
|
|
|
|
|
__version__ = "3.0.2" |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
try: |
|
USE_TF = os.environ.get("USE_TF", "AUTO").upper() |
|
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() |
|
if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"): |
|
import torch |
|
|
|
_torch_available = True |
|
logger.info("PyTorch version {} available.".format(torch.__version__)) |
|
else: |
|
logger.info("Disabling PyTorch because USE_TF is set") |
|
_torch_available = False |
|
except ImportError: |
|
_torch_available = False |
|
|
|
try: |
|
USE_TF = os.environ.get("USE_TF", "AUTO").upper() |
|
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() |
|
|
|
if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"): |
|
import tensorflow as tf |
|
|
|
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 |
|
_tf_available = True |
|
logger.info("TensorFlow version {} available.".format(tf.__version__)) |
|
else: |
|
logger.info("Disabling Tensorflow because USE_TORCH is set") |
|
_tf_available = False |
|
except (ImportError, AssertionError): |
|
_tf_available = False |
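
# Note (illustrative): the framework checks above can be steered with environment
# variables set before this module is imported, e.g. from the shell:
#
#   USE_TF=1 python run.py      # prefer TensorFlow; torch is not imported
#   USE_TORCH=1 python run.py   # prefer PyTorch; TensorFlow is not imported
#
# Any of "1", "ON", "YES" (case-insensitive) enables a backend, while the default
# "AUTO" simply picks up whichever framework is installed.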
|
|
|
|
|
try: |
|
from torch.hub import _get_torch_home |
|
|
|
torch_cache_home = _get_torch_home() |
|
except ImportError: |
|
torch_cache_home = os.path.expanduser( |
|
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) |
|
) |
|
|
|
|
|
try: |
|
import torch_xla.core.xla_model as xm |
|
|
|
if _torch_available: |
|
_torch_tpu_available = True |
|
else: |
|
_torch_tpu_available = False |
|
except ImportError: |
|
_torch_tpu_available = False |
|
|
|
|
|
try: |
|
import psutil |
|
|
|
_psutil_available = True |
|
|
|
except ImportError: |
|
_psutil_available = False |
|
|
|
|
|
try: |
|
import py3nvml |
|
|
|
_py3nvml_available = True |
|
|
|
except ImportError: |
|
_py3nvml_available = False |
|
|
|
|
|
try: |
|
from apex import amp |
|
|
|
_has_apex = True |
|
except ImportError: |
|
_has_apex = False |
|
|
|
default_cache_path = os.path.join(torch_cache_home, "transformers") |
|
|
|
|
|
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) |
|
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) |
|
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) |
|
|
|
WEIGHTS_NAME = "pytorch_model.bin" |
|
TF2_WEIGHTS_NAME = "tf_model.h5" |
|
TF_WEIGHTS_NAME = "model.ckpt" |
|
CONFIG_NAME = "config.json" |
|
MODEL_CARD_NAME = "modelcard.json" |
|
|
|
|
|
MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]] |
|
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] |
|
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] |
|
|
|
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" |
|
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co" |
|
|
|
|
|
def is_torch_available(): |
|
return _torch_available |
|
|
|
|
|
def is_tf_available(): |
|
return _tf_available |
|
|
|
|
|
def is_torch_tpu_available(): |
|
return _torch_tpu_available |
|
|
|
|
|
def is_psutil_available(): |
|
return _psutil_available |
|
|
|
|
|
def is_py3nvml_available(): |
|
return _py3nvml_available |
|
|
|
|
|
def is_apex_available(): |
|
return _has_apex |
|
|
|
|
|
def add_start_docstrings(*docstr): |
|
def docstring_decorator(fn): |
|
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") |
|
return fn |
|
|
|
return docstring_decorator |
|
|
|
|
|
def add_start_docstrings_to_callable(*docstr): |
|
def docstring_decorator(fn): |
|
class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0]) |
|
intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name) |
|
note = r""" |
|
|
|
.. note:: |
|
        Although the recipe for the forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this, since the former takes care of running the
        pre- and post-processing steps while the latter silently ignores them.
|
""" |
|
fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") |
|
return fn |
|
|
|
return docstring_decorator |
|
|
|
|
|
def add_end_docstrings(*docstr): |
|
def docstring_decorator(fn): |
|
        fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
|
return fn |
|
|
|
return docstring_decorator |
|
|
|
|
|
PT_TOKEN_CLASSIFICATION_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import torch |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") |
|
>>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1 |
|
|
|
>>> outputs = model(**inputs, labels=labels) |
|
>>> loss, scores = outputs[:2] |
|
""" |
|
|
|
PT_QUESTION_ANSWERING_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import torch |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") |
|
>>> start_positions = torch.tensor([1]) |
|
>>> end_positions = torch.tensor([3]) |
|
|
|
>>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions) |
|
>>> loss, start_scores, end_scores = outputs[:3] |
|
""" |
|
|
|
PT_SEQUENCE_CLASSIFICATION_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import torch |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") |
|
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 |
|
>>> outputs = model(**inputs, labels=labels) |
|
>>> loss, logits = outputs[:2] |
|
""" |
|
|
|
PT_MASKED_LM_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import torch |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"] |
|
|
|
>>> outputs = model(input_ids, labels=input_ids) |
|
>>> loss, prediction_scores = outputs[:2] |
|
""" |
|
|
|
PT_BASE_MODEL_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import torch |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") |
|
>>> outputs = model(**inputs) |
|
|
|
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple |
|
""" |
|
|
|
PT_MULTIPLE_CHOICE_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import torch |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." |
|
>>> choice0 = "It is eaten with a fork and a knife." |
|
>>> choice1 = "It is eaten while held in the hand." |
|
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1 |
|
|
|
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True) |
|
>>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1 |
|
|
|
>>> # the linear classifier still needs to be trained |
|
>>> loss, logits = outputs[:2] |
|
""" |
|
|
|
PT_CAUSAL_LM_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> import torch |
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") |
|
>>> outputs = model(**inputs, labels=inputs["input_ids"]) |
|
>>> loss, logits = outputs[:2] |
|
""" |
|
|
|
TF_TOKEN_CLASSIFICATION_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import tensorflow as tf |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") |
|
>>> input_ids = inputs["input_ids"] |
|
>>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1 |
|
|
|
>>> outputs = model(inputs) |
|
>>> loss, scores = outputs[:2] |
|
""" |
|
|
|
TF_QUESTION_ANSWERING_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import tensorflow as tf |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" |
|
>>> input_dict = tokenizer(question, text, return_tensors='tf') |
|
>>> start_scores, end_scores = model(input_dict) |
|
|
|
>>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0]) |
|
>>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) |
|
""" |
|
|
|
TF_SEQUENCE_CLASSIFICATION_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import tensorflow as tf |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") |
|
>>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1 |
|
|
|
>>> outputs = model(inputs) |
|
>>> loss, logits = outputs[:2] |
|
""" |
|
|
|
TF_MASKED_LM_SAMPLE = r""" |
|
Example:: |
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import tensorflow as tf |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1 |
|
|
|
>>> outputs = model(input_ids) |
|
>>> prediction_scores = outputs[0] |
|
""" |
|
|
|
TF_BASE_MODEL_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import tensorflow as tf |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") |
|
>>> outputs = model(inputs) |
|
|
|
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple |
|
""" |
|
|
|
TF_MULTIPLE_CHOICE_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import tensorflow as tf |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." |
|
>>> choice0 = "It is eaten with a fork and a knife." |
|
>>> choice1 = "It is eaten while held in the hand." |
|
|
|
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True) |
|
>>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}} |
|
>>> outputs = model(inputs) # batch size is 1 |
|
|
|
>>> # the linear classifier still needs to be trained |
|
>>> logits = outputs[0] |
|
""" |
|
|
|
TF_CAUSAL_LM_SAMPLE = r""" |
|
Example:: |
|
|
|
>>> from transformers import {tokenizer_class}, {model_class} |
|
>>> import tensorflow as tf |
|
|
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') |
|
>>> model = {model_class}.from_pretrained('{checkpoint}') |
|
|
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") |
|
>>> outputs = model(inputs) |
|
>>> logits = outputs[0] |
|
""" |
|
|
|
|
|
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None): |
|
def docstring_decorator(fn): |
|
model_class = fn.__qualname__.split(".")[0] |
|
is_tf_class = model_class[:2] == "TF" |
|
|
|
if "SequenceClassification" in model_class: |
|
code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE |
|
elif "QuestionAnswering" in model_class: |
|
code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE |
|
elif "TokenClassification" in model_class: |
|
code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE |
|
elif "MultipleChoice" in model_class: |
|
code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE |
|
elif "MaskedLM" in model_class: |
|
code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE |
|
elif "LMHead" in model_class: |
|
code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE |
|
elif "Model" in model_class: |
|
code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE |
|
else: |
|
raise ValueError(f"Docstring can't be built for model {model_class}") |
|
|
|
built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint) |
|
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc |
|
return fn |
|
|
|
return docstring_decorator |
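
# Illustrative sketch of how add_code_sample_docstrings is typically applied (the
# class, tokenizer and checkpoint names below are hypothetical examples, not taken
# from this file):
#
#   class BertForSequenceClassification(...):
#       @add_code_sample_docstrings(tokenizer_class="BertTokenizer", checkpoint="bert-base-uncased")
#       def forward(self, input_ids=None, ...):
#           ...
#
# The decorator inspects the owning class name ("TF" prefix, task suffix such as
# "SequenceClassification") to pick one of the *_SAMPLE templates above and formats
# it with the given tokenizer class and checkpoint.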
|
|
|
|
|
def is_remote_url(url_or_filename): |
|
parsed = urlparse(url_or_filename) |
|
return parsed.scheme in ("http", "https") |
|
|
|
|
|
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str: |
|
""" |
|
    Resolve a model identifier and a file name to an HF-hosted URL
    on either S3 or Cloudfront (a Content Delivery Network, or CDN).
|
|
|
    Cloudfront is replicated around the globe, so downloads are much faster for
    the end user (and it also lowers our bandwidth costs). However, it caches
    files more aggressively by default, so it may not always reflect the latest
    changes to the underlying file (the default TTL is 24 hours).
|
|
|
In terms of client-side caching from this library, even though |
|
Cloudfront relays the ETags from S3, using one or the other |
|
(or switching from one to the other) will affect caching: cached files |
|
are not shared between the two because the cached file's name contains |
|
a hash of the url. |
|
""" |
|
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX |
|
legacy_format = "/" not in model_id |
|
if legacy_format: |
|
return f"{endpoint}/{model_id}-{filename}" |
|
else: |
|
return f"{endpoint}/{model_id}/{filename}" |
|
|
|
|
|
def url_to_filename(url, etag=None): |
|
""" |
|
Convert `url` into a hashed filename in a repeatable way. |
|
If `etag` is specified, append its hash to the url's, delimited |
|
by a period. |
|
    If the url ends with .h5 (Keras HDF5 weights), append '.h5' to the name
    so that TF 2.0 can identify it as an HDF5 file
|
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) |
|
""" |
|
url_bytes = url.encode("utf-8") |
|
url_hash = sha256(url_bytes) |
|
filename = url_hash.hexdigest() |
|
|
|
if etag: |
|
etag_bytes = etag.encode("utf-8") |
|
etag_hash = sha256(etag_bytes) |
|
filename += "." + etag_hash.hexdigest() |
|
|
|
if url.endswith(".h5"): |
|
filename += ".h5" |
|
|
|
return filename |
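
# Illustrative cache-key layout produced by url_to_filename (hashes abbreviated):
#
#   url only:          "<sha256(url)>"
#   url + etag:        "<sha256(url)>.<sha256(etag)>"
#   *.h5 url + etag:   "<sha256(url)>.<sha256(etag)>.h5"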
|
|
|
|
|
def filename_to_url(filename, cache_dir=None): |
|
""" |
|
Return the url and etag (which may be ``None``) stored for `filename`. |
|
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. |
|
""" |
|
if cache_dir is None: |
|
cache_dir = TRANSFORMERS_CACHE |
|
if isinstance(cache_dir, Path): |
|
cache_dir = str(cache_dir) |
|
|
|
cache_path = os.path.join(cache_dir, filename) |
|
if not os.path.exists(cache_path): |
|
raise EnvironmentError("file {} not found".format(cache_path)) |
|
|
|
meta_path = cache_path + ".json" |
|
if not os.path.exists(meta_path): |
|
raise EnvironmentError("file {} not found".format(meta_path)) |
|
|
|
with open(meta_path, encoding="utf-8") as meta_file: |
|
metadata = json.load(meta_file) |
|
url = metadata["url"] |
|
etag = metadata["etag"] |
|
|
|
return url, etag |
|
|
|
|
|
def cached_path( |
|
url_or_filename, |
|
cache_dir=None, |
|
force_download=False, |
|
proxies=None, |
|
resume_download=False, |
|
user_agent: Union[Dict, str, None] = None, |
|
extract_compressed_file=False, |
|
force_extract=False, |
|
local_files_only=False, |
|
) -> Optional[str]: |
|
""" |
|
Given something that might be a URL (or might be a local path), |
|
determine which. If it's a URL, download the file and cache it, and |
|
return the path to the cached file. If it's already a local path, |
|
make sure the file exists and then return the path. |
|
Args: |
|
        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if an incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
            file in a folder alongside the archive.
        force_extract: if True, when extract_compressed_file is True and the archive was already extracted,
            re-extract the archive and override the folder where it was extracted.
|
|
|
Return: |
|
        None in case of a non-recoverable file (a non-existent or inaccessible url with no cache on disk).
        Local path (string) otherwise
|
""" |
|
if cache_dir is None: |
|
cache_dir = TRANSFORMERS_CACHE |
|
if isinstance(url_or_filename, Path): |
|
url_or_filename = str(url_or_filename) |
|
if isinstance(cache_dir, Path): |
|
cache_dir = str(cache_dir) |
|
|
|
if is_remote_url(url_or_filename): |
|
|
|
output_path = get_from_cache( |
|
url_or_filename, |
|
cache_dir=cache_dir, |
|
force_download=force_download, |
|
proxies=proxies, |
|
resume_download=resume_download, |
|
user_agent=user_agent, |
|
local_files_only=local_files_only, |
|
) |
|
elif os.path.exists(url_or_filename): |
|
|
|
output_path = url_or_filename |
|
elif urlparse(url_or_filename).scheme == "": |
|
|
|
raise EnvironmentError("file {} not found".format(url_or_filename)) |
|
else: |
|
|
|
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) |
|
|
|
if extract_compressed_file: |
|
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): |
|
return output_path |
|
|
|
|
|
|
|
output_dir, output_file = os.path.split(output_path) |
|
output_extract_dir_name = output_file.replace(".", "-") + "-extracted" |
|
output_path_extracted = os.path.join(output_dir, output_extract_dir_name) |
|
|
|
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract: |
|
return output_path_extracted |
|
|
|
|
|
lock_path = output_path + ".lock" |
|
with FileLock(lock_path): |
|
shutil.rmtree(output_path_extracted, ignore_errors=True) |
|
os.makedirs(output_path_extracted) |
|
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
|
elif tarfile.is_tarfile(output_path): |
|
tar_file = tarfile.open(output_path) |
|
tar_file.extractall(output_path_extracted) |
|
tar_file.close() |
|
else: |
|
raise EnvironmentError("Archive format of {} could not be identified".format(output_path)) |
|
|
|
return output_path_extracted |
|
|
|
return output_path |
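
# Minimal usage sketch (the model identifier is an illustrative example):
#
#   >>> url = hf_bucket_url("bert-base-uncased", "config.json")
#   >>> local_path = cached_path(url)                       # downloads on first call, then hits the cache
#   >>> local_path = cached_path(url, force_download=True)  # re-downloads even if cached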
|
|
|
|
|
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None): |
|
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0]) |
|
if is_torch_available(): |
|
ua += "; torch/{}".format(torch.__version__) |
|
if is_tf_available(): |
|
ua += "; tensorflow/{}".format(tf.__version__) |
|
if isinstance(user_agent, dict): |
|
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) |
|
elif isinstance(user_agent, str): |
|
ua += "; " + user_agent |
|
headers = {"user-agent": ua} |
|
if resume_size > 0: |
|
headers["Range"] = "bytes=%d-" % (resume_size,) |
|
response = requests.get(url, stream=True, proxies=proxies, headers=headers) |
|
if response.status_code == 416: |
|
return |
|
content_length = response.headers.get("Content-Length") |
|
total = resume_size + int(content_length) if content_length is not None else None |
|
progress = tqdm( |
|
unit="B", |
|
unit_scale=True, |
|
total=total, |
|
initial=resume_size, |
|
desc="Downloading", |
|
disable=bool(logger.getEffectiveLevel() == logging.NOTSET), |
|
) |
|
for chunk in response.iter_content(chunk_size=1024): |
|
if chunk: |
|
progress.update(len(chunk)) |
|
temp_file.write(chunk) |
|
progress.close() |
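
# For reference, the request headers built above look roughly like this (versions
# are illustrative) when resuming a partial download of 1000 bytes:
#
#   user-agent: transformers/3.0.2; python/3.8.2; torch/1.5.0
#   Range: bytes=1000-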
|
|
|
|
|
def get_from_cache( |
|
url, |
|
cache_dir=None, |
|
force_download=False, |
|
proxies=None, |
|
etag_timeout=10, |
|
resume_download=False, |
|
user_agent: Union[Dict, str, None] = None, |
|
local_files_only=False, |
|
) -> Optional[str]: |
|
""" |
|
Given a URL, look for the corresponding file in the local cache. |
|
If it's not there, download it. Then return the path to the cached file. |
|
|
|
Return: |
|
        None in case of a non-recoverable file (a non-existent or inaccessible url with no cache on disk).
        Local path (string) otherwise
|
""" |
|
if cache_dir is None: |
|
cache_dir = TRANSFORMERS_CACHE |
|
if isinstance(cache_dir, Path): |
|
cache_dir = str(cache_dir) |
|
|
|
os.makedirs(cache_dir, exist_ok=True) |
|
|
|
etag = None |
|
if not local_files_only: |
|
try: |
|
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) |
|
if response.status_code == 200: |
|
etag = response.headers.get("ETag") |
|
except (EnvironmentError, requests.exceptions.Timeout): |
|
|
|
pass |
|
|
|
filename = url_to_filename(url, etag) |
|
|
|
|
|
cache_path = os.path.join(cache_dir, filename) |
|
|
|
|
|
|
|
if etag is None: |
|
if os.path.exists(cache_path): |
|
return cache_path |
|
else: |
|
matching_files = [ |
|
file |
|
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*") |
|
if not file.endswith(".json") and not file.endswith(".lock") |
|
] |
|
if len(matching_files) > 0: |
|
return os.path.join(cache_dir, matching_files[-1]) |
|
else: |
|
|
|
|
|
|
|
if local_files_only: |
|
raise ValueError( |
|
"Cannot find the requested files in the cached path and outgoing traffic has been" |
|
" disabled. To enable model look-ups and downloads online, set 'local_files_only'" |
|
" to False." |
|
) |
|
return None |
|
|
|
|
|
if os.path.exists(cache_path) and not force_download: |
|
return cache_path |
|
|
|
|
|
lock_path = cache_path + ".lock" |
|
with FileLock(lock_path): |
|
|
|
|
|
if os.path.exists(cache_path) and not force_download: |
|
|
|
return cache_path |
|
|
|
if resume_download: |
|
incomplete_path = cache_path + ".incomplete" |
|
|
|
@contextmanager |
|
def _resumable_file_manager(): |
|
with open(incomplete_path, "a+b") as f: |
|
yield f |
|
|
|
temp_file_manager = _resumable_file_manager |
|
if os.path.exists(incomplete_path): |
|
resume_size = os.stat(incomplete_path).st_size |
|
else: |
|
resume_size = 0 |
|
else: |
|
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) |
|
resume_size = 0 |
|
|
|
|
|
|
|
with temp_file_manager() as temp_file: |
|
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) |
|
|
|
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) |
|
|
|
logger.info("storing %s in cache at %s", url, cache_path) |
|
os.replace(temp_file.name, cache_path) |
|
|
|
logger.info("creating metadata file for %s", cache_path) |
|
meta = {"url": url, "etag": etag} |
|
meta_path = cache_path + ".json" |
|
with open(meta_path, "w") as meta_file: |
|
json.dump(meta, meta_file) |
|
|
|
return cache_path |
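
# For a given URL, the cache directory ends up containing files along these lines
# (hashes abbreviated for readability):
#
#   <url_hash>.<etag_hash>             the downloaded file itself
#   <url_hash>.<etag_hash>.json        metadata: {"url": ..., "etag": ...}
#   <url_hash>.<etag_hash>.lock        FileLock guarding concurrent downloads
#   <url_hash>.<etag_hash>.incomplete  partial download kept for resume_download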
|
|
|
|
|
class cached_property(property): |
|
""" |
|
Descriptor that mimics @property but caches output in member variable. |
|
|
|
From tensorflow_datasets |
|
|
|
    Built into functools since Python 3.8.
|
""" |
|
|
|
def __get__(self, obj, objtype=None): |
|
|
|
if obj is None: |
|
return self |
|
if self.fget is None: |
|
raise AttributeError("unreadable attribute") |
|
attr = "__cached_" + self.fget.__name__ |
|
cached = getattr(obj, attr, None) |
|
if cached is None: |
|
cached = self.fget(obj) |
|
setattr(obj, attr, cached) |
|
return cached |
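
# Usage sketch (the class and method are hypothetical):
#
#   >>> class Corpus:
#   ...     @cached_property
#   ...     def vocab_size(self):
#   ...         return self._expensive_scan()  # computed once, then stored as obj.__cached_vocab_size
#
# Note that a result of None is not cached by this implementation and would be
# recomputed on every access.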
|
|
|
|
|
def torch_required(func): |
|
|
|
@wraps(func) |
|
def wrapper(*args, **kwargs): |
|
if is_torch_available(): |
|
return func(*args, **kwargs) |
|
else: |
|
raise ImportError(f"Method `{func.__name__}` requires PyTorch.") |
|
|
|
return wrapper |
|
|
|
|
|
def tf_required(func): |
|
|
|
@wraps(func) |
|
def wrapper(*args, **kwargs): |
|
if is_tf_available(): |
|
return func(*args, **kwargs) |
|
else: |
|
raise ImportError(f"Method `{func.__name__}` requires TF.") |
|
|
|
return wrapper |
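
# Usage sketch for the two decorators above (the function is hypothetical):
#
#   >>> @torch_required
#   ... def to_torch(batch):
#   ...     return torch.tensor(batch)  # raises ImportError at call time if PyTorch is missing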
|
|