|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""TensorFlow custom ops builder. |
|
""" |
|
|
|
import glob |
|
import os |
|
import re |
|
import uuid |
|
import hashlib |
|
import tempfile |
|
import shutil |
|
import tensorflow as tf |
|
from tensorflow.python.client import device_lib |
|
|
|
from .. import util |
|
|
|
|
|
|
|
|
|
# Print "Setting up / Compiling / Loading / Done" progress messages from get_plugin().
verbose = False

# Directory holding compiled plugin binaries. When None, get_plugin() falls
# back to util.make_cache_dir_path('tflib-cudacache').
cuda_cache_path = None

# Mixed into the cache hash; changing it makes every plugin get a new binary
# file name, i.e. forces a rebuild.
cuda_cache_version_tag = 'v1'

# When True, only the .cu file itself feeds the cache hash. When False, the
# source is first run through `nvcc --preprocess` so included headers count too.
do_not_hash_included_headers = True
|
|
|
|
|
|
|
|
|
def _find_compiler_bindir():
    """Locate an MSVC host-compiler directory for nvcc's ``--compiler-bindir``.

    Visual Studio 2017+ layouts are searched first by edition, in a fixed
    priority order. Within an edition, the lexicographically greatest matching
    path is taken, which normally means the newest year/toolset. The legacy
    Visual Studio 14.0 directory is checked next. Additional locations are
    checked last, so the choice on machines where an older rule already
    matched is unchanged: the Enterprise edition, and the 64-bit
    ``C:/Program Files`` root used by Visual Studio 2022.

    Returns:
        The first existing directory path as a string, or None if none is found
        (always the case on non-Windows systems).
    """
    x86_root = 'C:/Program Files (x86)/Microsoft Visual Studio'
    x64_root = 'C:/Program Files/Microsoft Visual Studio'
    tail = 'VC/Tools/MSVC/*/bin/Hostx64/x64'

    def newest_match(pattern):
        # reverse-sorted so the greatest version string wins
        matches = sorted(glob.glob(pattern), reverse=True)
        return matches[0] if matches else None

    for edition in ('Professional', 'BuildTools', 'Community'):
        found = newest_match('%s/*/%s/%s' % (x86_root, edition, tail))
        if found is not None:
            return found

    vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin'
    if os.path.isdir(vc_bin_dir):
        return vc_bin_dir

    # Newer / less common layouts, checked after all pre-existing rules.
    extra_patterns = ['%s/*/Enterprise/%s' % (x86_root, tail)]
    for edition in ('Professional', 'BuildTools', 'Community', 'Enterprise'):
        extra_patterns.append('%s/*/%s/%s' % (x64_root, edition, tail))
    for pattern in extra_patterns:
        found = newest_match(pattern)
        if found is not None:
            return found

    return None
|
|
|
def _get_compute_cap(device):
    """Extract the CUDA compute capability from a TensorFlow device description.

    Args:
        device: An object with a ``physical_device_desc`` string attribute
            (as produced by ``device_lib.list_local_devices()``), expected to
            contain ``compute capability: <major>.<minor>``.

    Returns:
        A ``(major, minor)`` tuple of digit strings, e.g. ``('7', '5')``.

    Raises:
        RuntimeError: if the description contains no compute-capability entry.
    """
    caps_str = device.physical_device_desc
    # Raw string and escaped dot: match a literal "X.Y", not "X<anything>Y".
    m = re.search(r'compute capability: (\d+)\.(\d+)', caps_str)
    if m is None:
        raise RuntimeError('Could not parse compute capability from device description: %r' % (caps_str,))
    return (m.group(1), m.group(2))
|
|
|
def _get_cuda_gpu_arch_string():
    """Build the nvcc ``--gpu-architecture`` value for the first local GPU.

    Returns:
        A string of the form ``'sm_<major><minor>'``.

    Raises:
        RuntimeError: if TensorFlow lists no device of type ``'GPU'``.
    """
    for device in device_lib.list_local_devices():
        if device.device_type != 'GPU':
            continue
        # Only the first GPU decides the target architecture.
        major, minor = _get_compute_cap(device)
        return 'sm_%s%s' % (major, minor)
    raise RuntimeError('No GPU devices found')
|
|
|
def _run_cmd(cmd):
    """Run a shell command and raise if it exits with a non-zero status.

    Args:
        cmd: Complete shell command line. Commands built by
            ``_prepare_nvcc_cli`` end in ``2>&1``, so the captured output also
            includes stderr.

    Raises:
        RuntimeError: when the command fails; the message carries the command
            line and its full captured output.
    """
    with os.popen(cmd) as proc:
        log = proc.read()
        # close() yields None on success, otherwise the encoded exit status.
        exit_status = proc.close()
    if exit_status is None:
        return
    raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, log))
|
|
|
def _prepare_nvcc_cli(opts):
    """Build an nvcc shell command line targeting the installed TensorFlow headers.

    Args:
        opts: Extra nvcc arguments (inputs, outputs, flags), inserted after the
            fixed ``--std=c++11 -DNDEBUG`` prefix; surrounding whitespace is
            stripped.

    Returns:
        The command string. It adds TensorFlow include paths, an explicit
        ``--compiler-bindir`` when one is found, and a trailing ``2>&1`` so
        stderr is merged into stdout.

    Raises:
        RuntimeError: on Windows, when ``_find_compiler_bindir`` finds no MSVC
            installation.
    """
    tf_include = tf.sysconfig.get_include()
    cmd = 'nvcc --std=c++11 -DNDEBUG ' + opts.strip()
    cmd += ' --disable-warnings'
    cmd += ' --include-path "%s"' % tf_include
    cmd += ' --include-path "%s"' % os.path.join(tf_include, 'external', 'protobuf_archive', 'src')
    cmd += ' --include-path "%s"' % os.path.join(tf_include, 'external', 'com_google_absl')
    cmd += ' --include-path "%s"' % os.path.join(tf_include, 'external', 'eigen_archive')

    compiler_bindir = _find_compiler_bindir()
    if compiler_bindir is None:
        # A host compiler directory is mandatory on Windows. Elsewhere nvcc is
        # left to use its default host compiler.
        if os.name == 'nt':
            raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check the Visual Studio search paths in _find_compiler_bindir() in "%s".' % __file__)
    else:
        cmd += ' --compiler-bindir "%s"' % compiler_bindir
    cmd += ' 2>&1'
    return cmd
|
|
|
|
|
|
|
|
|
# Per-process memo of loaded plugins: cuda_file path -> result of tf.load_op_library.
_plugin_cache = {}
|
|
|
def _hash_preprocessed_source(md5, cuda_file, cuda_file_base, cuda_file_name, cuda_file_ext):
    """Feed the nvcc-preprocessed text of ``cuda_file`` into ``md5`` (so included headers affect the cache key)."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
        _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
        with open(tmp_file, 'rb') as f:
            # Rewrite the quoted full source path (with forward slashes) to the
            # bare file name, so the hash does not depend on the file location.
            bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8')
            good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
            for ln in f:
                # Skip preprocessor line-marker lines ("# N ..." / "#line ...").
                if not ln.startswith(b'# ') and not ln.startswith(b'#line '):
                    ln = ln.replace(bad_file_str, good_file_str)
                    md5.update(ln)
            md5.update(b'\n')


def _build_plugin_compile_opts(extra_nvcc_options):
    """Return the platform-specific nvcc options for building a plugin for the local GPU."""
    compile_opts = ''
    if os.name == 'nt':
        compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
    elif os.name == 'posix':
        compile_opts += f' --compiler-options \'-fPIC\''
        compile_opts += f' --compiler-options \'{" ".join(tf.sysconfig.get_compile_flags())}\''
        compile_opts += f' --linker-options \'{" ".join(tf.sysconfig.get_link_flags())}\''
    else:
        # An explicit raise (not `assert False`) so this still fails under `python -O`.
        raise RuntimeError('Unsupported platform for CUDA plugin compilation: os.name == %r' % (os.name,))
    compile_opts += f' --gpu-architecture={_get_cuda_gpu_arch_string()}'
    compile_opts += ' --use_fast_math'
    # Each extra option is appended as " <opt>", matching the historical command
    # string exactly (it is part of the cache hash).
    compile_opts += ''.join(' ' + opt for opt in extra_nvcc_options)
    return compile_opts


def _compile_plugin_into_cache(nvcc_cmd, cuda_file, cuda_file_name, bin_file_ext, cache_dir, bin_file):
    """Compile ``cuda_file`` as a shared library and publish it at ``bin_file`` inside ``cache_dir``."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
        _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
        os.makedirs(cache_dir, exist_ok=True)
        # Stage under a unique name inside cache_dir and then rename. That way
        # the final move happens within one directory, and bin_file never
        # appears partially written.
        intermediate_file = os.path.join(cache_dir, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
        shutil.copyfile(tmp_file, intermediate_file)
        try:
            # Unlike os.rename, os.replace also overwrites an existing target on Windows.
            os.replace(intermediate_file, bin_file)
        except OSError:
            # The replace can still fail, e.g. when another process has
            # already published bin_file and holds it open/loaded. Its name
            # encodes the same cache key, so reuse it and drop our staged copy.
            lost_race = os.path.isfile(bin_file)
            try:
                os.remove(intermediate_file)
            except OSError:
                pass
            if not lost_race:
                raise


def get_plugin(cuda_file, extra_nvcc_options=None):
    """Compile (if needed) and load a TensorFlow custom-op plugin from a CUDA source.

    The compiled binary is cached on disk as
    ``<cache_dir>/<name>_<md5><.dll|.so>``. The MD5 covers:

    - the source bytes;
    - the preprocessed source, unless ``do_not_hash_included_headers``;
    - the full nvcc command line;
    - the TensorFlow version;
    - ``cuda_cache_version_tag``.

    Loaded plugins are also memoized in-process, keyed by ``cuda_file``.

    Args:
        cuda_file: Path to the CUDA source file.
        extra_nvcc_options: Optional iterable of additional nvcc option strings,
            each appended space-separated. ``None`` (the default) means none.

    Returns:
        The op-library module returned by ``tf.load_op_library``.

    Raises:
        RuntimeError: if nvcc fails, no GPU is found, no MSVC installation is
            found on Windows, or the OS is unsupported.
    """
    if extra_nvcc_options is None:
        extra_nvcc_options = ()

    cuda_file_base = os.path.basename(cuda_file)
    cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)

    # NOTE: keyed by cuda_file only, so a later call with different
    # extra_nvcc_options returns the plugin that was loaded first.
    if cuda_file in _plugin_cache:
        return _plugin_cache[cuda_file]

    if verbose:
        print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
    try:
        md5 = hashlib.md5()
        with open(cuda_file, 'rb') as f:
            md5.update(f.read())
        md5.update(b'\n')

        if not do_not_hash_included_headers:
            if verbose:
                print('Preprocessing... ', end='', flush=True)
            _hash_preprocessed_source(md5, cuda_file, cuda_file_base, cuda_file_name, cuda_file_ext)

        nvcc_cmd = _prepare_nvcc_cli(_build_plugin_compile_opts(extra_nvcc_options))

        # The 'tf.VERSION: ' label is kept for hash stability; tf.__version__
        # holds the same string and, unlike tf.VERSION, also exists in TF 2.x.
        md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
        md5.update(('tf.VERSION: ' + tf.__version__).encode('utf-8') + b'\n')
        md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')

        cache_dir = util.make_cache_dir_path('tflib-cudacache') if cuda_cache_path is None else cuda_cache_path
        bin_file_ext = '.dll' if os.name == 'nt' else '.so'
        bin_file = os.path.join(cache_dir, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
        if not os.path.isfile(bin_file):
            if verbose:
                print('Compiling... ', end='', flush=True)
            _compile_plugin_into_cache(nvcc_cmd, cuda_file, cuda_file_name, bin_file_ext, cache_dir, bin_file)

        if verbose:
            print('Loading... ', end='', flush=True)
        plugin = tf.load_op_library(bin_file)

        _plugin_cache[cuda_file] = plugin
        if verbose:
            print('Done.', flush=True)
        return plugin

    except BaseException:
        # Report the failure on the progress line, then propagate unchanged.
        if verbose:
            print('Failed!', flush=True)
        raise
|
|
|
|
|
|