|
import subprocess |
|
import os |
|
import re |
|
import sys |
|
import filecmp |
|
import logging |
|
import shutil |
|
import sysconfig |
|
import datetime |
|
import platform |
|
import pkg_resources |
|
|
|
# Logger used by every helper in this module; setup_logging() attaches its handlers.
log = logging.getLogger('sd')

# Running count of failed git/pip invocations (git() and pip() increment it).
errors = 0
|
|
|
|
|
def setup_logging(clean=False):
    """Configure the module logger ``log`` for console and file output.

    A timestamped file ``logs/setup/kohya_ss_gui_<YYYYmmdd-HHMMSS>.log`` is
    created in the parent of this module's directory and attached to the root
    logger through ``logging.basicConfig(force=True)``.  Rich pretty-printing
    and tracebacks are installed on a shared console.  Any handlers already
    attached to ``log`` are removed and replaced by a single ``RichHandler``
    at INFO level.

    Args:
        clean: Accepted for compatibility; this function does not use it.
    """
    # rich is imported locally (ensure_base_requirements() installs it when missing).
    from rich.console import Console
    from rich.logging import RichHandler
    from rich.pretty import install as pretty_install
    from rich.theme import Theme
    from rich.traceback import install as traceback_install

    time_format = '%H:%M:%S-%f'
    black_borders = {
        'traceback.border': 'black',
        'traceback.border.syntax_error': 'black',
        'inspect.value.border': 'black',
    }
    console = Console(
        log_time=True,
        log_time_format=time_format,
        theme=Theme(black_borders),
    )

    stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_file = os.path.join(
        os.path.dirname(__file__),
        f'../logs/setup/kohya_ss_gui_{stamp}.log',
    )
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    # The ERROR level is set on the root logger; force=True discards any root
    # handlers configured earlier so the new log file takes over.
    logging.basicConfig(
        level=logging.ERROR,
        format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s',
        filename=log_file,
        filemode='a',
        encoding='utf-8',
        force=True,
    )
    log.setLevel(logging.DEBUG)

    pretty_install(console=console)
    traceback_install(
        console=console,
        extra_lines=1,
        width=console.width,
        word_wrap=False,
        indent_guides=False,
        suppress=[],
    )

    level = logging.INFO
    console_handler = RichHandler(
        show_time=True,
        omit_repeated_times=False,
        show_level=True,
        show_path=False,
        markup=False,
        rich_tracebacks=True,
        log_time_format=time_format,
        level=level,
        console=console,
    )
    # NOTE(review): the handler name is set to the int level, not a string.
    console_handler.set_name(level)

    # Detach whatever a previous call installed, then attach the new handler.
    while log.hasHandlers() and log.handlers:
        log.removeHandler(log.handlers[0])
    log.addHandler(console_handler)
|
|
|
|
|
def configure_accelerate(run_accelerate=False):
    """Copy the bundled accelerate config to the user's default location.

    The source file is ``../config_files/accelerate/default_config.yaml``
    relative to this module.  The target location comes from the first
    non-empty environment variable among ``HF_HOME``, ``LOCALAPPDATA`` and
    ``USERPROFILE``.  The file is copied only when no config exists there yet.

    When the source file is missing, no target location can be determined, or
    a target config already exists, the function hands over to the user.  It
    runs ``accelerate config`` when ``run_accelerate`` is True, and otherwise
    logs a warning.

    Args:
        run_accelerate: Run ``accelerate config`` instead of only warning when
            the bundled file is not copied.
    """
    from pathlib import Path

    def env_var_exists(var_name):
        return var_name in os.environ and os.environ[var_name] != ''

    def defer_to_user(message):
        # Either start the interactive accelerate setup or tell the user to.
        if run_accelerate:
            run_cmd('accelerate config')
        else:
            log.warning(message)

    log.info('Configuring accelerate...')

    source_accelerate_config_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        '..',
        'config_files',
        'accelerate',
        'default_config.yaml',
    )

    if not os.path.exists(source_accelerate_config_file):
        defer_to_user(
            f'Could not find the accelerate configuration file in {source_accelerate_config_file}. Please configure accelerate manually by running the option in the menu.'
        )
        # There is nothing to copy. Continuing would either crash in
        # shutil.copyfile or run `accelerate config` a second time.
        return

    log.debug(f'Source accelerate config location: {source_accelerate_config_file}')

    target_config_location = None

    log.debug(
        f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, "
        f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, "
        f"USERPROFILE: {os.environ.get('USERPROFILE')}"
    )
    if env_var_exists('HF_HOME'):
        target_config_location = Path(
            os.environ['HF_HOME'], 'accelerate', 'default_config.yaml'
        )
    elif env_var_exists('LOCALAPPDATA'):
        target_config_location = Path(
            os.environ['LOCALAPPDATA'],
            'huggingface',
            'accelerate',
            'default_config.yaml',
        )
    elif env_var_exists('USERPROFILE'):
        target_config_location = Path(
            os.environ['USERPROFILE'],
            '.cache',
            'huggingface',
            'accelerate',
            'default_config.yaml',
        )

    log.debug(f'Target config location: {target_config_location}')

    # No known location, or a config is already there (never overwritten).
    if target_config_location is None or target_config_location.is_file():
        defer_to_user(
            'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.'
        )
        return

    target_config_location.parent.mkdir(parents=True, exist_ok=True)
    log.debug(f'Target accelerate config location: {target_config_location}')
    shutil.copyfile(source_accelerate_config_file, target_config_location)
    log.info(f'Copied accelerate config file to: {target_config_location}')
|
|
|
|
|
def check_torch():
    """Log the detected GPU toolkit, the installed Torch build and visible GPUs.

    Returns:
        The Torch major version as an int, or 0 when ``torch`` cannot be
        imported or inspected.
    """
    windows_nvidia_smi = os.path.join(
        os.environ.get('SystemRoot') or r'C:\Windows',
        'System32',
        'nvidia-smi.exe',
    )
    if shutil.which('nvidia-smi') is not None or os.path.exists(windows_nvidia_smi):
        log.info('nVidia toolkit detected')
    elif shutil.which('rocminfo') is not None or os.path.exists('/opt/rocm/bin/rocminfo'):
        log.info('AMD toolkit detected')
    else:
        log.info('Using CPU-only Torch')

    try:
        import torch

        log.info(f'Torch {torch.__version__}')

        if not torch.cuda.is_available():
            log.warning('Torch reports CUDA not available')
        else:
            if torch.version.cuda:
                cudnn_version = (
                    torch.backends.cudnn.version()
                    if torch.backends.cudnn.is_available()
                    else "N/A"
                )
                log.info(
                    f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {cudnn_version}'
                )
            elif torch.version.hip:
                log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
            else:
                log.warning('Unknown Torch backend')

            for index in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(index)
                log.info(
                    f'Torch detected GPU: {torch.cuda.get_device_name(index)} VRAM {round(props.total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(index)} Cores {props.multi_processor_count}'
                )

        # Parse the whole major component; indexing [0] would misread e.g. "10.0".
        return int(torch.__version__.split('.')[0])
    except Exception as e:
        # Best effort: report the problem instead of swallowing it silently.
        log.warning(f'Could not load torch: {e}')
        return 0
|
|
|
|
|
|
|
def check_repo_version():
    """Log the release string from ``.release`` in the current working directory.

    When the file does not exist, only a debug message is logged.
    """
    release_file = '.release'
    if os.path.exists(release_file):
        with open(release_file, 'r', encoding='utf8') as file:
            # Strip surrounding whitespace (e.g. a trailing newline) so the log line stays clean.
            release = file.read().strip()
        log.info(f'Version: {release}')
    else:
        log.debug('Could not read release...')
|
|
|
|
|
def git(arg: str, folder: str = None, ignore: bool = False):
    """Run ``git <arg>`` through the shell and return its combined output.

    The executable is taken from the ``GIT`` environment variable and defaults
    to ``git``.  stdout and stderr are decoded as UTF-8, dropping undecodable
    bytes, and joined with a newline.

    Args:
        arg: Argument string appended verbatim to the command line.  The shell
            interprets it, so pass only trusted values.
        folder: Working directory for the command; defaults to the cwd.
        ignore: When True, a non-zero exit status is neither counted in the
            module-level ``errors`` nor logged.

    Returns:
        The stripped combined stdout/stderr text.
    """
    global errors

    git_cmd = os.environ.get('GIT', "git")
    result = subprocess.run(
        f'"{git_cmd}" {arg}',
        check=False,
        shell=True,
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=folder or '.',
    )
    txt = result.stdout.decode(encoding="utf8", errors="ignore")
    if result.stderr:
        txt += ('\n' if txt else '') + result.stderr.decode(encoding="utf8", errors="ignore")
    txt = txt.strip()

    if result.returncode != 0 and not ignore:
        errors += 1
        log.error(f'Error running git: {folder} / {arg}')
        if 'or stash them' in txt:
            log.error('Local changes detected: check log for details...')
        log.debug(f'Git output: {txt}')

    # Callers use the output, e.g. check_python() parses `git --version`.
    return txt
|
|
|
|
|
def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
    """Run ``python -m pip <arg>`` with the current interpreter.

    Args:
        arg: Argument string appended to the pip command line.  The shell
            interprets it, so pass only trusted values.
        ignore: When True, a non-zero exit status is neither counted in the
            module-level ``errors`` nor logged.
        quiet: When True, skip the "Installing package" info message.
        show_stdout: Stream pip's output to the console instead of capturing it.

    Returns:
        The stripped combined stdout/stderr text.  When ``show_stdout`` is
        True the output is not captured, so an empty string is returned.
    """
    global errors

    if not quiet:
        # Show only the package part, e.g. "install --upgrade foo==1.0" -> "foo==1.0".
        # Filtering whole tokens keeps names such as "pyinstaller" intact.
        flags = {'install', '--upgrade', '--no-deps', '--force', '--force-reinstall'}
        friendly = ' '.join(token for token in arg.split() if token not in flags)
        log.info(f'Installing package: {friendly}')
    log.debug(f"Running pip: {arg}")

    command = f'"{sys.executable}" -m pip {arg}'
    if show_stdout:
        result = subprocess.run(command, shell=True, check=False, env=os.environ)
        txt = ''
    else:
        result = subprocess.run(
            command,
            shell=True,
            check=False,
            env=os.environ,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        txt = result.stdout.decode(encoding="utf8", errors="ignore")
        if result.stderr:
            txt += ('\n' if txt else '') + result.stderr.decode(encoding="utf8", errors="ignore")
        txt = txt.strip()

    # Check failures in both modes; a streamed run used to go unchecked.
    if result.returncode != 0 and not ignore:
        errors += 1
        log.error(f'Error running pip: {arg}')
        log.debug(f'Pip output: {txt}')
    return txt
|
|
|
|
|
def _version_satisfies(installed_version, operator, required_version):
    """Return whether *installed_version* satisfies ``<operator> required_version``.

    Versions are compared with ``pkg_resources.parse_version`` so that, for
    example, ``0.10.0 >= 0.9.0`` holds; plain string comparison gets this
    wrong.  Strings that cannot be parsed as versions fall back to string
    comparison.
    """
    try:
        have = pkg_resources.parse_version(installed_version)
        want = pkg_resources.parse_version(required_version)
    except ValueError:  # packaging's InvalidVersion subclasses ValueError
        have, want = installed_version, required_version
    if operator == '>=':
        return have >= want
    return have == want


def installed(package, friendly: str = None):
    """Return True when every package named in *package* is installed at the required version.

    Extras in square brackets are removed from *package* first.  When
    *friendly* is given, its whitespace-separated tokens are checked instead
    of *package*.  Otherwise, tokens of *package* that start with ``-`` or
    ``=`` are skipped.  For each token only the part after the last ``/`` is
    used.  Only ``>=`` and ``==`` constraints are understood; any other token
    is treated as a bare package name.

    Returns:
        False as soon as one package is missing or has the wrong version,
        otherwise True.
    """
    package = re.sub(r'\[.*?\]', '', package)
    pkgs = []

    try:
        if friendly:
            pkgs = friendly.split()
        else:
            pkgs = [p for p in package.split() if not p.startswith(('-', '='))]
        pkgs = [p.split('/')[-1] for p in pkgs]

        for pkg in pkgs:
            if '>=' in pkg:
                operator = '>='
            elif '==' in pkg:
                operator = '=='
            else:
                operator = None

            if operator is None:
                pkg_name, pkg_version = pkg.strip(), None
            else:
                name_part, _, version_part = pkg.partition(operator)
                pkg_name, pkg_version = name_part.strip(), version_part.strip()

            # by_key is keyed by lower-cased project names; try the name as
            # given plus lower-case and '_' -> '-' variants.
            spec = None
            for key in (
                pkg_name,
                pkg_name.lower(),
                pkg_name.replace('_', '-'),
                pkg_name.lower().replace('_', '-'),
            ):
                spec = pkg_resources.working_set.by_key.get(key, None)
                if spec is not None:
                    break

            if spec is None:
                log.debug(f'Package version not found: {pkg_name}')
                return False

            version = spec.version
            log.debug(f'Package version found: {pkg_name} {version}')

            if pkg_version is not None and not _version_satisfies(version, operator, pkg_version):
                log.warning(f'Package wrong version: {pkg_name} {version} required {pkg_version}')
                return False

        return True
    except ModuleNotFoundError:
        log.debug(f'Package not installed: {pkgs}')
        return False
|
|
|
|
|
|
|
def install(
    package,
    friendly: str = None,
    ignore: bool = False,
    reinstall: bool = False,
    show_stdout: bool = False,
):
    """Install *package* with ``pip install --upgrade`` unless it is already satisfied.

    Args:
        package: pip requirement string.  Everything from the first ``#``
            onward is discarded before use.
        friendly: Optional string passed to ``installed()`` for the
            already-installed check.
        ignore: Forwarded to ``pip()``.
        reinstall: Skip the check, always run pip, and set the module-level
            ``quick_allowed`` flag to False.
        show_stdout: Forwarded to ``pip()``.
    """
    global quick_allowed

    requirement = package.split('#')[0].strip()

    if reinstall:
        quick_allowed = False
    elif installed(requirement, friendly):
        # Already satisfied: nothing to do.
        return

    pip(f'install --upgrade {requirement}', ignore=ignore, show_stdout=show_stdout)
|
|
|
|
|
|
|
def process_requirements_line(line, show_stdout: bool = False):
    """Install the requirement on one requirements-file line.

    The line without extras (``foo[bar]==1`` -> ``foo==1``) is passed to
    ``install()`` as the string used for the already-installed check.
    Empty lines and comment-only lines are ignored.
    """
    # Drop an inline "# comment" so its words are not treated as package names
    # by the installed check. install() already drops it from the pip command.
    line = line.split('#')[0].strip()
    if not line:
        return

    package_name = re.sub(r'\[.*?\]', '', line)
    install(line, package_name, show_stdout=show_stdout)
|
|
|
|
|
def install_requirements(requirements_file, check_no_verify_flag=False, show_stdout: bool = False):
    """Install or verify every requirement listed in *requirements_file*.

    Blank lines and lines starting with ``#`` (after stripping) are skipped.
    ``-r <file>`` lines are followed recursively.  The included path is
    resolved relative to the including file's directory when a file exists
    there; otherwise it is used as written, relative to the cwd.

    Args:
        requirements_file: Path of the requirements file to read.
        check_no_verify_flag: When True, lines containing ``no_verify`` are
            skipped and the log message says modules are being verified.
        show_stdout: Forwarded to pip to stream its output.
    """
    if check_no_verify_flag:
        log.info(f'Verifying modules installation status from {requirements_file}...')
    else:
        log.info(f'Installing modules from {requirements_file}...')

    with open(requirements_file, 'r', encoding='utf8') as f:
        stripped = [raw.strip() for raw in f]

    # Test the stripped text so indented comments are skipped too.
    lines = [
        line
        for line in stripped
        if line
        and not line.startswith('#')
        and not (check_no_verify_flag and 'no_verify' in line)
    ]

    base_dir = os.path.dirname(requirements_file)
    for line in lines:
        if line.startswith('-r'):
            included_file = line[2:].strip()
            # pip resolves "-r" relative to the including file. Keep the
            # as-written (cwd-relative) path when no such file exists there.
            candidate = os.path.join(base_dir, included_file)
            if os.path.isfile(candidate):
                included_file = candidate
            install_requirements(
                included_file,
                check_no_verify_flag=check_no_verify_flag,
                show_stdout=show_stdout,
            )
        else:
            process_requirements_line(line, show_stdout=show_stdout)
|
|
|
|
|
def ensure_base_requirements():
    """Make sure ``rich`` is importable, installing it with pip if it is missing."""
    try:
        import rich  # noqa: F401 -- imported only to test availability
    except ImportError:
        # install() already adds --upgrade; passing it here duplicated the flag.
        install('rich', 'rich')
|
|
|
|
|
def run_cmd(run_cmd):
    """Run a shell command and report failures on stdout instead of raising.

    Args:
        run_cmd: Command line executed with ``shell=True``; pass only trusted text.

    Returns:
        The ``subprocess.CompletedProcess`` for the command, or None when the
        command could not be started.
    """
    try:
        result = subprocess.run(run_cmd, shell=True, check=False, env=os.environ)
    except OSError as e:
        # With check=False run() never raises CalledProcessError, but starting
        # the shell itself can still fail.
        print(f'Error occurred while running command: {run_cmd}')
        print(f'Error: {e}')
        return None

    if result.returncode != 0:
        print(f'Error occurred while running command: {run_cmd}')
        print(f'Error: exit code {result.returncode}')
    return result
|
|
|
|
|
|
|
def check_python(ignore=True, skip_git=False):
    """Log the Python version and check that Python and git are usable.

    Args:
        ignore: When False, exit with status 1 on an unsupported Python
            version or a missing git executable.  When True, only log an error.
        skip_git: Skip the git availability check.
    """
    supported_minors = [9, 10]
    log.info(f'Python {platform.python_version()} on {platform.system()}')

    major, minor, micro = sys.version_info[:3]
    if not (major == 3 and minor in supported_minors):
        # Render e.g. "3.9 or 3.10" rather than "3.[9, 10]".
        required = ' or '.join(f'3.{m}' for m in supported_minors)
        log.error(
            f'Incompatible Python version: {major}.{minor}.{micro} required {required}'
        )
        if not ignore:
            sys.exit(1)

    if skip_git:
        return

    git_cmd = os.environ.get('GIT', 'git')
    if shutil.which(git_cmd) is None:
        log.error('Git not found')
        if not ignore:
            sys.exit(1)
    else:
        # Treat a missing result from git() as empty text instead of crashing on None.
        git_version = git('--version', folder=None, ignore=False) or ''
        log.debug(f'Git {git_version.replace("git version", "").strip()}')
|
|
|
|
|
def delete_file(file_path):
    """Remove *file_path* if it exists; a missing path is silently ignored."""
    if not os.path.exists(file_path):
        return
    os.remove(file_path)
|
|
|
|
|
def write_to_file(file_path, content):
    """Write *content* to *file_path* as UTF-8 text, replacing any existing contents.

    An ``OSError`` (missing directory, permissions, ...) is reported on stdout
    instead of being raised.
    """
    try:
        # Use explicit UTF-8 so the result does not depend on the platform locale.
        # This matches the utf8 reads elsewhere in this module.
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(content)
    except OSError as e:
        print(f'Error occurred while writing to file: {file_path}')
        print(f'Error: {e}')
|
|
|
|
|
def clear_screen():
    """Clear the terminal: ``cls`` on Windows (``os.name == 'nt'``), ``clear`` elsewhere."""
    command = 'cls' if os.name == 'nt' else 'clear'
    os.system(command)
|
|
|
|