"""Registry of pretrained checkpoint URLs per model architecture, plus helpers
to query the registry and download/cache checkpoint files."""
import hashlib
import os
import urllib.request
import warnings

from tqdm import tqdm

# Per-architecture tables mapping pretrained tag -> checkpoint URL.
# NOTE(review): every URL is an empty string in this copy of the registry, so
# get_pretrained_url() returns '' for all entries — TODO confirm / fill in.
_RN50 = dict(
    openai="",
    yfcc15m="",
    cc12m="",
)

_RN50_quickgelu = dict(
    openai="",
    yfcc15m="",
    cc12m="",
)

_RN101 = dict(
    openai="",
    yfcc15m="",
)

_RN101_quickgelu = dict(
    openai="",
    yfcc15m="",
)

_RN50x4 = dict(
    openai="",
)

_RN50x16 = dict(
    openai="",
)

_RN50x64 = dict(
    openai="",
)

_VITB32 = dict(
    openai="",
    laion400m_e31="",
    laion400m_e32="",
    laion400m_avg="",
)

_VITB32_quickgelu = dict(
    openai="",
    laion400m_e31="",
    laion400m_e32="",
    laion400m_avg="",
)

_VITB16 = dict(
    openai="",
)

_VITL14 = dict(
    openai="",
)

# Model architecture name -> {pretrained tag: URL}.
_PRETRAINED = {
    "RN50": _RN50,
    "RN50-quickgelu": _RN50_quickgelu,
    "RN101": _RN101,
    "RN101-quickgelu": _RN101_quickgelu,
    "RN50x4": _RN50x4,
    "RN50x16": _RN50x16,
    "RN50x64": _RN50x64,  # was defined above but previously never registered
    "ViT-B-32": _VITB32,
    "ViT-B-32-quickgelu": _VITB32_quickgelu,
    "ViT-B-16": _VITB16,
    "ViT-L-14": _VITL14,
}


def list_pretrained(as_str: bool = False):
    """Return every registered (model_name, pretrain_tag) pair.

    Args:
        as_str: when True, return 'name:tag' strings instead of tuples.

    Returns:
        A list of ``(model_name, pretrain_tag)`` tuples, or of ``'name:tag'``
        strings if ``as_str`` is True, in registry order.
    """
    return [
        ':'.join([name, tag]) if as_str else (name, tag)
        for name, tags in _PRETRAINED.items()
        for tag in tags
    ]


def list_pretrained_tag_models(tag: str):
    """Return the names of all models that have an entry for pretrain ``tag``."""
    return [name for name, tags in _PRETRAINED.items() if tag in tags]


def list_pretrained_model_tags(model: str):
    """Return all pretrain tags registered for ``model`` (empty list if unknown)."""
    return list(_PRETRAINED.get(model, {}))


def get_pretrained_url(model: str, tag: str):
    """Return the checkpoint URL for ``model``/``tag``, or '' if either is not registered."""
    return _PRETRAINED.get(model, {}).get(tag, '')


def _file_sha256(path, chunk_size=1 << 20):
    """Return the hex SHA256 of the file at ``path``, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    """Download ``url`` into ``root`` unless already cached, and return the local path.

    The file is stored as ``os.path.join(root, os.path.basename(url))``. For URLs
    containing 'openaipublic', the second-to-last path segment of the URL is
    taken as the expected SHA256 hex digest. Both an existing cached file and a
    fresh download are verified against it. Other URLs are not checksummed.

    Args:
        url: address of the checkpoint to fetch.
        root: cache directory; created if it does not exist.

    Returns:
        Path to the local copy of the file.

    Raises:
        RuntimeError: if the target path exists but is not a regular file, or
            if the downloaded file fails SHA256 verification.
        Errors raised by ``urllib.request.urlopen`` propagate unchanged.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    else:
        expected_sha256 = ''

    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if not expected_sha256:
            # No digest to verify against: trust the cached file.
            return download_target
        if _file_sha256(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    # Stream into a sibling temp file and only move it into place once the
    # download has completed (and verified). This way an interrupted or corrupt
    # download is never mistaken for a valid cached file on a later call.
    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            content_length = source.info().get("Content-Length")
            # The header may be missing; tqdm then shows an open-ended bar.
            total = int(content_length) if content_length else None
            with tqdm(total=total, ncols=80, unit='iB', unit_scale=True) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    loop.update(len(buffer))

        if expected_sha256 and _file_sha256(partial_target) != expected_sha256:
            raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

        os.replace(partial_target, download_target)
    finally:
        # Reached with the temp file still present only on failure/interrupt.
        if os.path.exists(partial_target):
            os.remove(partial_target)

    return download_target