import os
import sys
import time
import json
import platform
import subprocess
import datetime
import logging
from html.parser import HTMLParser
import torch
import gradio as gr
from modules import paths, script_callbacks, sd_models, sd_samplers, shared, extensions, devices
from benchmark import run_benchmark, submit_benchmark # pylint: disable=E0401,E0611,C0411
### system info globals
# logger obtained under the name 'sd'
log = logging.getLogger('sd')
# Module-level snapshot of system information.
# The values below are placeholders: get_full_data() rebinds `data` to a freshly
# collected dict, and get_quick_data() refreshes a few entries in place.
# Note that 'backend' and 'pipeline' are evaluated once here, at import time.
data = {
'date': '',
'timestamp': '',
'uptime': '',
'version': {},
'torch': '',
'gpu': {},
'state': {},
'memory': {},
'optimizations': [],
'libs': {},
'repos': {},
'device': {},
'schedulers': [],
'extensions': [],
'platform': '',
'crossattention': '',
'backend': getattr(devices, 'backend', ''),
'pipeline': shared.opts.data.get('sd_backend', ''),
'model': {},
}
# Lists of available networks, grouped by kind.
# get_full_data() rebinds this name to a dict that only holds 'models' and 'loras';
# the remaining keys are populated (if at all) outside the code visible here.
networks = {
'models': [],
'hypernetworks': [],
'embeddings': [],
'skipped': [],
'loras': [],
'lycos': [],
}
### benchmark globals
# NOTE(review): the consumers of these globals are outside the visible part of the file
bench_text = ''
# local benchmark results file, located in the same directory as this module
bench_file = os.path.join(os.path.dirname(__file__), 'benchmark-data-local.json')
# presumably the column names of one benchmark record - confirm against the benchmark UI code
bench_headers = ['timestamp', 'performance', 'version', 'system', 'libraries', 'gpu', 'pipeline', 'model', 'username', 'note', 'hash']
bench_data = []
### system info module
def get_user():
    """Return the current user name, or an empty string when it cannot be determined.

    ``os.getlogin()`` is tried first; if it fails or yields an empty string,
    the name is looked up in the password database via ``pwd``.
    """
    try:
        login = os.getlogin()
    except Exception:
        login = ''
    if login != '':
        return login
    try:
        import pwd
        return pwd.getpwuid(os.getuid())[0]
    except Exception:
        # neither lookup worked (for example `pwd` is not importable)
        return ''
def get_gpu():
    """Return a dict describing the compute device that torch reports.

    Without ``torch.cuda`` availability the result is an OpenVINO description
    (when ``shared.cmd_opts.use_openvino`` is set) or an empty dict. Otherwise the
    XPU, CUDA and HIP cases are checked in that order. Any exception raised while
    querying an available device is returned as ``{'error': exception}``.
    """
    if not torch.cuda.is_available():
        try:
            if not shared.cmd_opts.use_openvino:
                return {}
            from modules.intel.openvino import get_openvino_device
            return {
                'device': get_openvino_device(),
                'openvino': get_package_version("openvino")
            }
        except Exception:
            return {}

    try:
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            xpu_name = torch.xpu.get_device_name(torch.xpu.current_device())
            xpu_count = str(torch.xpu.device_count())
            return {
                'device': f'{xpu_name} ({xpu_count})',
                'ipex': get_package_version('intel-extension-for-pytorch'),
            }
        if torch.version.cuda:
            name = torch.cuda.get_device_name(torch.cuda.current_device())
            count = str(torch.cuda.device_count())
            arch = torch.cuda.get_arch_list()[-1]
            capability = str(torch.cuda.get_device_capability(shared.device))
            return {
                'device': f'{name} ({count}) ({arch}) {capability}',
                'cuda': torch.version.cuda,
                'cudnn': torch.backends.cudnn.version(),
                'driver': get_driver(),
            }
        if torch.version.hip:
            name = torch.cuda.get_device_name(torch.cuda.current_device())
            count = str(torch.cuda.device_count())
            return {
                'device': f'{name} ({count})',
                'hip': torch.version.hip,
            }
        return {
            'device': 'unknown'
        }
    except Exception as e:
        return { 'error': e }
def get_driver():
    """Return the driver version reported by ``nvidia-smi``.

    An empty string is returned when torch reports no CUDA device, when torch
    has no CUDA version, or when running the command fails.
    """
    if not (torch.cuda.is_available() and torch.version.cuda):
        return ''
    query = 'nvidia-smi --query-gpu=driver_version --format=csv,noheader'
    try:
        proc = subprocess.run(query, shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return proc.stdout.decode(encoding="utf8", errors="ignore").strip()
    except Exception:
        return ''
def get_uptime():
    """Return the server start time formatted with ``%c``.

    The start time is read from the ``server_start`` attribute of ``shared.state``;
    the current time is used when that attribute is absent.
    """
    state_vars = vars(shared.state)
    started_at = state_vars.get('server_start', time.time())
    return time.strftime('%c', time.localtime(started_at))
class HTMLFilter(HTMLParser):
    """HTML parser that collects the character data of the fed markup in ``text``."""

    # class-level default; the first handle_data() call creates the instance attribute
    text = ""

    def handle_data(self, data): # pylint: disable=redefined-outer-name
        """Append one run of character data reported by the parser."""
        self.text = self.text + data
def get_state():
    """Return a dict summarising the attributes of ``shared.state``.

    The ``textinfo`` attribute is passed through HTMLFilter to drop markup, and
    empty lines are removed from the result.
    """
    state = vars(shared.state)

    # labels are concatenated as-is, so the trailing spaces act as separators
    flag_labels = (
        ('skipped', 'skipped '),
        ('interrupted', 'interrupted '),
        ('need_restart', 'needs restart'),
    )
    flags = ''.join(label for key, label in flag_labels if state.get(key, False))

    info = state.get('textinfo', '')
    if info is not None and len(info) > 0:
        stripper = HTMLFilter()
        stripper.feed(info)
        non_empty = [line for line in stripper.text.splitlines() if line]
        info = os.linesep.join(non_empty)

    started = time.strftime('%c', time.localtime(state.get('time_start', time.time())))
    step = f'{state.get("sampling_step", 0)} / {state.get("sampling_steps", 0)}'
    jobs = f'{state.get("job_no", 0)} / {state.get("job_count", 0)}'
    return {
        'started': started,
        'step': step,
        'jobs': jobs,
        'flags': flags,
        'job': state.get('job', ''),
        'text-info': info,
    }
def get_memory():
    """Return a dict with RAM and GPU memory figures, expressed in GB.

    ``'ram'`` is derived from this process's RSS and memory percentage as reported
    by psutil; if that fails, the exception object is stored under ``'ram'`` instead.
    GPU entries are added only when they can be collected, and failures there are
    ignored.
    """
    def gb(val: float):
        # bytes -> gigabytes, rounded to two decimals
        return round(val / 1024 / 1024 / 1024, 2)

    mem = {}

    try:
        import psutil
        proc = psutil.Process(os.getpid())
        usage = proc.memory_info()
        # total RAM is extrapolated from the share this process occupies
        total = 100 * usage.rss / proc.memory_percent()
        mem['ram'] = { 'free': gb(total - usage.rss), 'used': gb(usage.rss), 'total': gb(total) }
    except Exception as e:
        mem['ram'] = e

    if torch.cuda.is_available():
        try:
            free_bytes, total_bytes = torch.cuda.mem_get_info()
            gpu = { 'free': gb(free_bytes), 'used': gb(total_bytes - free_bytes), 'total': gb(total_bytes) }
            stats = dict(torch.cuda.memory_stats(shared.device))

            def usage_pair(prefix):
                # current/peak figures for one torch allocator statistic
                return { 'current': gb(stats[f'{prefix}.all.current']), 'peak': gb(stats[f'{prefix}.all.peak']) }

            allocated = usage_pair('allocated_bytes')
            reserved = usage_pair('reserved_bytes')
            active = usage_pair('active_bytes')
            inactive = usage_pair('inactive_split_bytes')
            warnings = { 'retries': stats['num_alloc_retries'], 'oom': stats['num_ooms'] }
            mem.update({
                'gpu': gpu,
                'gpu-active': active,
                'gpu-allocated': allocated,
                'gpu-reserved': reserved,
                'gpu-inactive': inactive,
                'events': warnings,
                'utilization': 0,
            })
            # queried last: if it fails, the entries above are kept with utilization 0
            mem['utilization'] = torch.cuda.utilization()
        except Exception:
            pass
    else:
        try:
            from openvino.runtime import Core as OpenVINO_Core
            from modules.intel.openvino import get_device as get_raw_openvino_device
            core = OpenVINO_Core()
            total_mem = core.get_property(get_raw_openvino_device(), 'GPU_DEVICE_TOTAL_MEM_SIZE')
            mem.update({
                'gpu': { 'total': gb(total_mem) },
            })
        except Exception:
            pass
    return mem
def get_optimizations():
    """Return the names of the memory optimization flags set in ``shared.cmd_opts``.

    Returns:
        A list of the enabled flag labels, or ``['none']`` when none is set.
    """
    # (attribute on shared.cmd_opts, label shown to the user)
    # the last attribute used to be misspelled 'lowvam', so 'lowram' was never reported
    known_flags = (
        ('medvram', 'medvram'),
        ('medvram_sdxl', 'medvram-sdxl'),
        ('lowvram', 'lowvram'),
        ('lowram', 'lowram'),
    )
    enabled = [label for attr, label in known_flags if getattr(shared.cmd_opts, attr, False)]
    if not enabled:
        enabled.append('none')
    return enabled
def get_package_version(pkg: str):
    """Return the installed version of distribution ``pkg``, or ``''`` if it is not installed.

    ``pkg_resources`` is used when it can be imported. When it cannot (setuptools is
    not installed), the lookup falls back to ``importlib.metadata`` rather than raising.

    Args:
        pkg: distribution name, for example ``'diffusers'``.

    Returns:
        The version string, or an empty string when the distribution is not found.
    """
    try:
        import pkg_resources
    except ImportError:
        pkg_resources = None
    if pkg_resources is not None:
        spec = pkg_resources.working_set.by_key.get(pkg, None) # more reliable than importlib
        return pkg_resources.get_distribution(pkg).version if spec is not None else ''
    # fallback for environments without setuptools
    from importlib import metadata
    try:
        return metadata.version(pkg)
    except metadata.PackageNotFoundError:
        return ''
def get_libs():
    """Return a dict mapping selected library names to their installed versions."""
    tracked = ('xformers', 'diffusers', 'transformers')
    return {name: get_package_version(name) for name in tracked}
def get_repos():
    """Return a dict mapping each key of ``paths.paths`` to its last git commit.

    Each value has the form ``'[<short hash>] <date>'``, or ``'(unknown)'`` when the
    information cannot be obtained (not a repository, git missing, no output).

    git is invoked with an argument list instead of a shell string, so paths that
    contain spaces or shell metacharacters are passed through unchanged.
    """
    repos = {}
    for key, val in paths.paths.items():
        try:
            repo_path = str(val)
            # the path is passed twice, as in the shell command this replaces: once as the
            # working directory (-C) and once as a pathspec restricting the log
            cmd = ['git', '-C', repo_path, 'log', '--pretty=format:%h %ad', '-1', '--date=short', repo_path]
            res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
            stdout = res.stdout.decode(encoding='utf8', errors='ignore')
            words = stdout.split(' ')
            # empty output yields a single word, so the IndexError below maps to '(unknown)'
            repos[key] = f'[{words[0]}] {words[1]}'
        except Exception:
            repos[key] = '(unknown)'
    return repos
def get_platform():
    """Return a dict with basic platform details, or ``{'error': exception}`` on failure.

    On Windows the ``release`` entry holds the full ``platform.platform()`` string;
    elsewhere it holds ``platform.release()``.
    """
    try:
        on_windows = platform.system() == 'Windows'
        release = platform.platform(aliased = True, terse = False) if on_windows else platform.release()
        # host, full platform string and version entries are currently disabled
        info = {}
        info['arch'] = platform.machine()
        info['cpu'] = platform.processor()
        info['system'] = platform.system()
        info['release'] = release
        info['python'] = platform.python_version()
        return info
    except Exception as e:
        return { 'error': e }
def get_torch():
    """Return the torch version followed by the configured precision and half mode."""
    try:
        ver = torch.__long_version__
    except Exception:
        # attribute not available; use the regular version string
        ver = torch.__version__
    precision = shared.cmd_opts.precision
    half_mode = ' nohalf' if shared.cmd_opts.no_half else ' half'
    return f"{ver} {precision} {half_mode}"
def get_version():
    """Return app name, last commit date, commit hash and URL from git in the current directory.

    An empty dict is returned when any of the git commands fails or its output
    cannot be parsed.
    """
    def git_output(command):
        # run one git command through the shell and return its decoded stdout
        res = subprocess.run(command, stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
        return res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''

    version = {}
    try:
        githash, updated = git_output('git log --pretty=format:"%h %ad" -1 --date=short').split(' ')
        origin = git_output('git remote get-url origin').replace('\n', '')
        branch = git_output('git rev-parse --abbrev-ref HEAD').replace('\n', '')
        app = origin.split('/')[-1]
        if app == 'automatic':
            app = 'SD.next'
        version = {
            'app': app,
            'updated': updated,
            'hash': githash,
            'url': origin + '/tree/' + branch
        }
    except Exception:
        pass
    return version
def get_crossattention():
    """Return the cross-attention optimization in use, or ``'unknown'`` on failure.

    The ``cross_attention_optimization`` option on ``shared.opts`` takes precedence;
    when it is missing or None, the method recorded by ``sd_hijack.model_hijack`` is used.
    """
    try:
        method = getattr(shared.opts, 'cross_attention_optimization', None)
        if method is not None:
            return method
        from modules import sd_hijack
        return sd_hijack.model_hijack.optimization_method
    except Exception:
        return 'unknown'
def get_model():
    """Return the configured and the currently loaded base, refiner and VAE models.

    The ``configured`` values come from ``shared.opts.data``. The ``loaded`` values
    stay empty strings whenever the corresponding lookup fails.
    """
    from modules.sd_models import model_data
    import modules.sd_vae

    configured = {
        'base': shared.opts.data.get('sd_model_checkpoint', ''),
        'refiner': shared.opts.data.get('sd_model_refiner', ''),
        'vae': shared.opts.data.get('sd_vae', ''),
    }
    loaded = {
        'base': '',
        'refiner': '',
        'vae': '',
    }
    obj = {
        'configured': configured,
        'loaded': loaded,
    }

    try:
        if model_data.sd_model is not None and hasattr(model_data.sd_model, 'sd_checkpoint_info'):
            loaded['base'] = model_data.sd_model.sd_checkpoint_info.filename
        else:
            loaded['base'] = ''
    except Exception:
        pass

    try:
        if model_data.sd_refiner is not None and hasattr(model_data.sd_refiner, 'sd_checkpoint_info'):
            loaded['refiner'] = model_data.sd_refiner.sd_checkpoint_info.filename
        else:
            loaded['refiner'] = ''
    except Exception:
        pass

    try:
        loaded['vae'] = modules.sd_vae.loaded_vae_file
    except Exception:
        pass

    return obj
def get_models():
    """Return the sorted titles of all entries in ``sd_models.checkpoints_list``."""
    titles = [checkpoint.title for checkpoint in sd_models.checkpoints_list.values()]
    titles.sort()
    return titles
def get_samplers():
    """Return the sorted first elements of the entries in ``sd_samplers.all_samplers``."""
    first_fields = (entry[0] for entry in sd_samplers.all_samplers)
    return sorted(first_fields)
def get_extensions():
    """Return a sorted list with one ``'name (enabled|disabled[ builtin])'`` entry per extension."""
    def describe(ext):
        status = 'enabled' if ext.enabled else 'disabled'
        builtin = ' builtin' if ext.is_builtin else ''
        return f"{ext.name} ({status}{builtin})"

    return sorted(describe(ext) for ext in extensions.extensions)
def get_loras():
    """Return the sorted names of the available LoRAs, or an empty list on failure.

    The builtin extensions directory is added to ``sys.path`` so the ``Lora``
    package can be imported. It is added only if it is not already there, so
    repeated calls no longer make ``sys.path`` grow.
    """
    loras = []
    try:
        builtin_dir = extensions.extensions_builtin_dir
        if builtin_dir not in sys.path:
            sys.path.append(builtin_dir)
        from Lora import lora # pylint: disable=E0401
        loras = sorted(lora.available_loras.keys())
    except Exception:
        # best effort: the Lora extension may be missing or not loaded
        pass
    return loras
def get_device():
    """Return the string forms of the device and dtypes exposed by ``devices``."""
    dev = {}
    dev['active'] = str(devices.device)
    dev['dtype'] = str(devices.dtype)
    dev['vae'] = str(devices.dtype_vae)
    dev['unet'] = str(devices.dtype_unet)
    return dev
def get_full_data():
    """Collect all system information, rebind the ``data`` and ``networks`` globals, return ``data``.

    ``networks`` is replaced by a dict that holds only the ``models`` and ``loras`` lists.
    """
    global data, networks # pylint: disable=global-statement
    data = dict(
        date=datetime.datetime.now().strftime('%c'),
        timestamp=datetime.datetime.now().strftime('%X'),
        uptime=get_uptime(),
        version=get_version(),
        torch=get_torch(),
        gpu=get_gpu(),
        state=get_state(),
        memory=get_memory(),
        optimizations=get_optimizations(),
        libs=get_libs(),
        repos=get_repos(),
        device=get_device(),
        model=get_model(),
        schedulers=get_samplers(),
        extensions=get_extensions(),
        platform=get_platform(),
        crossattention=get_crossattention(),
        backend=getattr(devices, 'backend', ''),
        pipeline=shared.opts.data.get('sd_backend', ''),
    )
    networks = dict(
        models=get_models(),
        loras=get_loras(),
    )
    return data
def get_quick_data():
    """Refresh the timestamp, state, memory and model entries of the global ``data`` in place."""
    refreshers = (
        ('timestamp', lambda: datetime.datetime.now().strftime('%X')),
        ('state', get_state),
        ('memory', get_memory),
        ('model', get_model),
    )
    # entries are updated one at a time, in order
    for key, produce in refreshers:
        data[key] = produce()
def list2text(lst: list):
    """Join the strings in ``lst`` into one text with one item per line."""
    separator = '\n'
    return separator.join(lst)
def dict2str(d: dict):
    """Render ``d`` on a single line as space-separated ``key:value`` pairs."""
    return ' '.join(f'{key}:{value}' for key, value in d.items())
def dict2text(d: dict):
    """Render ``d`` as text with one ``key: value`` line per entry.

    Values that are plain dicts are flattened onto their line with dict2str.
    """
    lines = []
    for key, value in d.items():
        # exact type check: only plain dicts are flattened
        rendered = dict2str(value) if type(value) is dict else value
        lines.append(f'{key}: {rendered}')
    return list2text(lines)
def refresh_info_quick(_old_data = None):
    """Run the quick refresh and return the values for the quick-refresh outputs.

    Returns:
        A tuple of state text, memory text, cross-attention, timestamp and the
        ``data`` dict itself. ``_old_data`` is accepted but not used.
    """
    get_quick_data()
    state_text = dict2text(data['state'])
    memory_text = dict2text(data['memory'])
    return state_text, memory_text, data['crossattention'], data['timestamp'], data
def refresh_info_full():
    """Run the full refresh and return the values for all full-refresh outputs.

    The order of the returned tuple is significant: it is the order in which the
    values are handed back to the caller.
    """
    get_full_data()
    outputs = [
        data['uptime'],
        dict2text(data['version']),
        dict2text(data['state']),
        dict2text(data['memory']),
        dict2text(data['platform']),
        data['torch'],
        dict2text(data['gpu']),
        list2text(data['optimizations']),
        data['crossattention'],
        data['backend'],
        data['pipeline'],
        dict2text(data['libs']),
        dict2text(data['repos']),
        dict2text(data['device']),
        dict2text(data['model']),
        networks['models'],
        networks['loras'],
        data['timestamp'],
        data,
    ]
    return tuple(outputs)
### ui definition
def create_ui(blocks: gr.Blocks = None):
try:
if shared.cmd_opts.api_only:
return
except:
pass
if not standalone:
from modules.ui import ui_system_tabs # pylint: disable=redefined-outer-name
else:
ui_system_tabs = None
with gr.Blocks(analytics_enabled = False) if standalone else blocks as system_info:
with gr.Row(elem_id = 'system_info'):
with gr.Tabs(elem_id = 'system_info_tabs') if standalone else ui_system_tabs:
with gr.TabItem('System Info'):
with gr.Row():
timestamp = gr.Textbox(value=data['timestamp'], label = '', elem_id = 'system_info_tab_last_update', container=False)
refresh_quick_btn = gr.Button('Refresh state', elem_id = 'system_info_tab_refresh_btn', visible = False) # quick refresh is used from js interval
refresh_full_btn = gr.Button('Refresh data', elem_id = 'system_info_tab_refresh_full_btn', variant='primary')
interrupt_btn = gr.Button('Send interrupt', elem_id = 'system_info_tab_interrupt_btn', variant='primary')
with gr.Row():
with gr.Column():
uptimetxt = gr.Textbox(data['uptime'], label = 'Server start time', lines = 1)
versiontxt = gr.Textbox(dict2text(data['version']), label = 'Version', lines = len(data['version']))
with gr.Column():
statetxt = gr.Textbox(dict2text(data['state']), label = 'State', lines = len(data['state']))
with gr.Column():
memorytxt = gr.Textbox(dict2text(data['memory']), label = 'Memory', lines = len(data['memory']))
with gr.Row():
with gr.Column():
platformtxt = gr.Textbox(dict2text(data['platform']), label = 'Platform', lines = len(data['platform']))
with gr.Row():
backendtxt = gr.Textbox(data['backend'], label = 'Backend')
pipelinetxt = gr.Textbox(data['pipeline'], label = 'Pipeline')
with gr.Column():
torchtxt = gr.Textbox(data['torch'], label = 'Torch', lines = 1)
gputxt = gr.Textbox(dict2text(data['gpu']), label = 'GPU', lines = len(data['gpu']))
with gr.Row():
opttxt = gr.Textbox(list2text(data['optimizations']), label = 'Memory optimization')
attentiontxt = gr.Textbox(data['crossattention'], label = 'Cross-attention')
with gr.Column():
libstxt = gr.Textbox(dict2text(data['libs']), label = 'Libs', lines = len(data['libs']))
repostxt = gr.Textbox(dict2text(data['repos']), label = 'Repos', lines = len(data['repos']), visible = False)
devtxt = gr.Textbox(dict2text(data['device']), label = 'Device Info', lines = len(data['device']))
modeltxt = gr.Textbox(dict2text(data['model']), label = 'Model Info', lines = len(data['model']))
with gr.Row():
gr.HTML('Load