test / tests /test_requirements.py
iblfe's picture
Upload folder using huggingface_hub
b585c7f verified
raw
history blame
4.29 kB
import pkg_resources
from pkg_resources import DistributionNotFound, VersionConflict
from src.utils import remove
from tests.utils import wrap_test_forked
def get_all_requirements():
import glob
requirements_all = []
reqs_http_all = []
for req_name in ['requirements.txt'] + glob.glob('reqs_optional/req*.txt'):
requirements1, reqs_http1 = get_requirements(req_name)
requirements_all.extend(requirements1)
reqs_http_all.extend(reqs_http1)
return requirements_all, reqs_http_all
def get_requirements(req_file="requirements.txt"):
req_tmp_file = req_file + '.tmp.txt'
try:
reqs_http = []
with open(req_file, 'rt') as f:
contents = f.readlines()
with open(req_tmp_file, 'wt') as g:
for line in contents:
if 'http://' not in line and 'https://' not in line:
g.write(line)
else:
reqs_http.append(line.replace('\n', ''))
reqs_http = [x for x in reqs_http if x]
print('reqs_http: %s' % reqs_http, flush=True)
with open(req_tmp_file, "rt") as f:
requirements = pkg_resources.parse_requirements(f.read())
finally:
remove(req_tmp_file)
return requirements, reqs_http
@wrap_test_forked
def test_requirements():
"""Test that each required package is available."""
packages_all = []
packages_dist = []
packages_version = []
packages_unkn = []
requirements, reqs_http = get_all_requirements()
for requirement in requirements:
try:
requirement = str(requirement)
pkg_resources.require(requirement)
except DistributionNotFound:
packages_all.append(requirement)
packages_dist.append(requirement)
except VersionConflict:
packages_all.append(requirement)
packages_version.append(requirement)
except pkg_resources.extern.packaging.requirements.InvalidRequirement:
packages_all.append(requirement)
packages_unkn.append(requirement)
packages_all.extend(reqs_http)
if packages_dist or packages_version:
print('Missing packages: %s' % packages_dist, flush=True)
print('Wrong version of packages: %s' % packages_version, flush=True)
print("Can't determine (e.g. http) packages: %s" % packages_unkn, flush=True)
print('\n\nRUN THIS:\n\n', flush=True)
print(
'pip uninstall peft transformers accelerate -y ; CUDA_HOME=/usr/local/cuda-11.7 pip install %s --upgrade' % str(
' '.join(packages_all)), flush=True)
print('\n\n', flush=True)
raise ValueError(packages_all)
import requests
import json
try:
from packaging.version import parse
except ImportError:
from pip._vendor.packaging.version import parse
URL_PATTERN = 'https://pypi.python.org/pypi/{package}/json'
def get_version(package, url_pattern=URL_PATTERN):
"""Return version of package on pypi.python.org using json."""
req = requests.get(url_pattern.format(package=package))
version = parse('0')
if req.status_code == requests.codes.ok:
j = json.loads(req.text.encode(req.encoding))
releases = j.get('releases', [])
for release in releases:
ver = parse(release)
if not ver.is_prerelease:
version = max(version, ver)
return version
@wrap_test_forked
def test_what_latest_packages():
# pip install requirements-parser
import requirements
import glob
for req_name in ['requirements.txt'] + glob.glob('reqs_optional/req*.txt'):
print("\n File: %s" % req_name, flush=True)
with open(req_name, 'rt') as fd:
for req in requirements.parse(fd):
from importlib.metadata import version
try:
current_version = version(req.name)
latest_version = get_version(req.name)
if str(current_version) != str(latest_version):
print("%s: %s -> %s" % (req.name, current_version, latest_version), flush=True)
except Exception as e:
print("Exception: %s" % str(e), flush=True)