In [None]:
# User-facing banner / usage notes for this launcher, kept as one multi-line string.
update_desc = '''
欢迎使用 stable-diffusion-webui 便捷启动工具

此脚本理论上可以在任何jupyter环境运行，但仅在 google colab 和 kaggle 测试。 已经无法在 google colab 免费实例上运行。
此脚本的【前置脚本】发布地址 https://www.kaggle.com/code/yiyiooo/stable-diffusion-webui-novelai-sdxl。
此脚本需配合 【前置脚本】 的脚本附带的配置项才能正常启动。
此脚本的内容会自动更新，你无需更新【前置脚本】就能获取到最新的功能。
增加一个说明：为解决加密图片的解密和浏览的问题，开发了一个独立应用，可以在这里下载 https://github.com/viyiviyi/encrypt_gallery/releases/
加密插件可解决生成nsfw图片后被平台封号问题

路径说明
*  为了解决平台差异，所有的安装目录和文件输出目录都被重新指定，如果你需要在配置中访问这些目录，请查看以下说明
*
*  可以使用 $install_path 或 {install_path} 来访问安装目录，写在字符串内也会生效
*  $output_path 或 {output_path} 可以访问输出目录
*  如果你需要安装在自定义目录，也可以设置这些值 如: install_path = '新的路径'
*  可自定义方法 on_before_start 并在方法内写上启动前需要的逻辑来实现在webui启动前执行自定义逻辑
*  可以增加配置 multi_case = True 来控制是否使用多卡
*  如果需要显示更多控制台输出 需配置 hidden_console_info = False
'''

In [None]:
from pathlib import Path
import os
import time
import re
import subprocess
import threading
import sys
import socket
import torch
from typing import List

In [None]:
# Built-in parameter defaults; a value already defined in the notebook context overrides the default.
def context_flag(name, default, *namespaces):
    """Return the first non-None value bound to ``name`` in ``namespaces``, else ``default``.

    Unlike ``namespace.get(name) or default`` this keeps an explicit ``False``,
    so an option whose default is ``True`` can actually be switched off.

    Args:
        name: Name of the user-provided variable to look up.
        default: Value returned when no namespace binds ``name`` to a non-None value.
        *namespaces: Mappings searched in order (typically ``locals()``, ``globals()``).
    """
    for namespace in namespaces:
        value = namespace.get(name)
        if value is not None:
            return value
    return default


_runing = False  # polled by the ngrok keep-alive loop
_useFrpc = context_flag('useFrpc', True, locals(), globals())

_useNgrok = context_flag('useNgrok', True, locals(), globals())

_reLoad = locals().get('reLoad') or globals().get('reLoad') or False

_before_downloading = locals().get('before_downloading') or globals().get('before_downloading') or ''

_async_downloading = locals().get('async_downloading') or globals().get('async_downloading') or ''

_before_start_sync_downloading = locals().get('before_start_sync_downloading') or globals().get('before_start_sync_downloading') or  ''

_server_port = locals().get('server_port') or globals().get('server_port') or 7860

# String options keep the `or` chain on purpose: an empty string falls back to the default.
_sd_git_repo = locals().get('sd_git_repo') or globals().get('sd_git_repo') or 'https://github.com/viyiviyi/stable-diffusion-webui.git -b local' 
_sd_git_repo = _sd_git_repo\
    .replace('{sdwui}','stable-diffusion-webui')\
    .replace('{wui}',"webui")

_sd_config_git_repu = locals().get('sd_config_git_repu') or globals().get('sd_config_git_repu') or 'https://github.com/viyiviyi/sd-configs.git'
_sd_config_git_repu = _sd_config_git_repu\
    .replace('{sdwui}','stable-diffusion-webui')\
    .replace('{wui}',"webui")


_huggingface_token = locals().get('huggingface_token') or globals().get('huggingface_token') or '{input_path}/configs/huggingface_token.txt'
_huggingface_token = _huggingface_token\
    .replace('{sdwui}','stable-diffusion-webui')\
    .replace('{wui}',"webui")

_huggingface_repo = locals().get('huggingface_repo') or globals().get('huggingface_repo') or ''
_huggingface_repo = _huggingface_repo\
    .replace('{sdwui}','stable-diffusion-webui')\
    .replace('{wui}',"webui")

_link_instead_of_copy = context_flag('link_instead_of_copy', True, locals(), globals())

# NOTE(review): show_shell_info mirrors hidden_console_info directly, while the banner says
# hidden_console_info = False should show MORE output -- confirm the intended polarity.
show_shell_info = locals().get('hidden_console_info') or globals().get('hidden_console_info')
if show_shell_info is None: show_shell_info = False

_multi_case = locals().get('multi_case') or globals().get('multi_case') or False

_skip_start = context_flag('skip_start', True, locals(), globals())

def before_start():
    """Default no-op hook used when the user does not define on_before_start."""

_on_before_start = locals().get('on_before_start') or globals().get('on_before_start') or before_start 

run_by_none_device = False

_proxy_path = locals().get('proxy_path') or globals().get('proxy_path') or {}

# Exactly two sub paths are expected; anything else falls back to the defaults.
_sub_path =  locals().get('sub_path') or globals().get('sub_path') or ['/','/1/']
if len(_sub_path) != 2:
    _sub_path = ['/','/1/']

_config_args:dict[str, str] =  locals().get('config_args') or globals().get('config_args') or {}

In [None]:

def run(command, cwd=None, desc=None, errdesc=None, custom_env=None,try_error:bool=True) -> str:
    """Run ``command`` through the shell and return its captured stdout.

    Args:
        command: Shell command line to execute.
        cwd: Working directory for the command.
        desc: Optional message printed before the command runs.
        errdesc: Optional headline used in the error report on failure.
        custom_env: Environment mapping; defaults to ``os.environ``.
        try_error: When True a non-zero exit is only printed; when False it
            raises ``RuntimeError``.

    Returns:
        The captured stdout, or ``""`` when nothing was captured (output is
        only captured while ``show_shell_info`` is falsy).
    """
    if desc is not None:
        print(desc)

    environment = custom_env if custom_env is not None else os.environ
    options = dict(
        args=command,
        shell=True,
        cwd=cwd,
        env=environment,
        encoding='utf8',
        errors='ignore',
    )
    # Capture both streams unless the user asked to see the raw console output.
    if not show_shell_info:
        options.update(stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    completed = subprocess.run(**options)
    output = completed.stdout or ""

    if completed.returncode != 0:
        details = [
            f"{errdesc or 'Error running command'}.",
            f"Command: {command}",
            f"Error code: {completed.returncode}",
        ]
        for label, text in (("stdout", completed.stdout), ("stderr", completed.stderr)):
            if text:
                details.append(f"{label}: {text}")
        failure = RuntimeError("\n".join(details))
        if not try_error:
            raise failure
        print(failure)

    if show_shell_info:
        print(output)
    return output

def mkdirs(path, exist_ok=True):
    """Create ``path`` (including parents) when it is non-empty and does not exist yet."""
    if not path:
        return
    if Path(path).exists():
        return
    os.makedirs(path, exist_ok=exist_ok)

In [None]:

# Check that a GPU is present
def check_gpu():
    """Raise when no CUDA device is visible, unless ``run_by_none_device`` allows running without one."""
    if run_by_none_device:
        return
    if torch.cuda.device_count() == 0:
        raise Exception('当前环境没有GPU')

_install_path = f"{os.environ['HOME']}/sd_webui" # install directory
_output_path = '/kaggle/working' if os.path.exists('/kaggle/working/') else f"{os.environ['HOME']}/.sdwui/Output" # output directory; redirected to sdwebui/Output on Google Drive when the drive is mounted
_input_path = '/kaggle/input' # input directory
_ui_dir_name = 'sd_main_dir'

# User-defined values from the notebook context override the defaults above.
_install_path = locals().get('install_path') or globals().get('install_path') or _install_path
_output_path = locals().get('output_path') or globals().get('output_path') or _output_path
_input_path = locals().get('input_path') or globals().get('input_path') or _input_path
_ui_dir_name = locals().get('ui_dir_name') or globals().get('ui_dir_name') or _ui_dir_name

# NOTE(review): these public aliases are assigned before the Google Drive block below may
# redirect _output_path / _input_path, so they keep the pre-mount values -- confirm intent.
install_path = _install_path
output_path = _output_path
input_path = _input_path
ui_dir_name = _ui_dir_name

google_drive = '' 


# Honour an explicit useGooglrDrive = False; `... or True` used to force it back to True.
_useGooglrDrive = locals().get('useGooglrDrive', globals().get('useGooglrDrive'))
if _useGooglrDrive is None: _useGooglrDrive = True

# Connect to Google Drive (only possible on Colab; any failure just disables the feature)
try:
    if _useGooglrDrive:
        from google.colab import drive
        drive.mount(f'~/google_drive')
        google_drive = f"{os.environ['HOME']}/google_drive/MyDrive"
        _output_path = f'{google_drive}/sdwebui/Output'
        _input_path = f'{google_drive}/sdwebui/Input'
        run(f'''mkdir -p {_input_path}''')
        print('''
已经链接到谷歌云盘
已在云盘创建Input和Output目录
        ''')
except Exception:
    # Not on Colab or the mount failed: continue without Drive, but let
    # KeyboardInterrupt / SystemExit propagate.
    _useGooglrDrive = False

run(f'''mkdir -p {_install_path}''')
run(f'''mkdir -p {_output_path}''')


# Export the resolved paths so shell commands and child processes can use them.
os.environ['install_path'] = _install_path
os.environ['output_path'] = _output_path
os.environ['google_drive'] = google_drive
os.environ['input_path'] = _input_path

def replace_path(input_str:str):
    """Expand user placeholders and path variables inside ``input_str``.

    Entries of ``_config_args`` are substituted first, then the built-in
    ``$name`` / ``{name}`` path placeholders and the ``{sdwui}`` / ``{wui}``
    abbreviations. A falsy input yields ``''``.
    """
    if not input_str:
        return ''

    expanded = input_str
    for placeholder, value in _config_args.items():
        expanded = expanded.replace(placeholder, value)

    # Applied in this fixed order, one after the other.
    builtin_substitutions = (
        ('$install_path', _install_path),
        ('{install_path}', _install_path),
        ('$input_path', _input_path),
        ('{input_path}', _input_path),
        ('$output_path', _output_path),
        ('{output_path}', _output_path),
        ('{sdwui}', 'stable-diffusion-webui'),
        ('{wui}', "webui"),
    )
    for placeholder, value in builtin_substitutions:
        expanded = expanded.replace(placeholder, value)
    return expanded

space_string = ' \n\r\t\'\",'

def config_reader(conf:str):
    """Split a multi-line config string into cleaned, placeholder-expanded entries.

    For each non-blank line: everything after the first ``#`` is dropped,
    surrounding whitespace/quotes/commas are stripped, and placeholders are
    expanded via ``replace_path``. Entries that end up empty are discarded.
    """
    entries = []
    for raw_line in conf.split('\n'):
        if not raw_line.strip(space_string):
            continue
        without_comment = raw_line.split('#')[0].strip(space_string)
        resolved = replace_path(without_comment).strip()
        if resolved:
            entries.append(resolved)
    return entries


In [None]:
ngrokTokenFile = os.path.join(_input_path,'configs/ngrok_token.txt') # optional: path of the file that stores the ngrok token
frpcConfigFile = os.path.join(_input_path,'configs/frpc_koishi.ini') # optional: frp configuration file
# SSL certificate directories: download the nginx flavour of the certificate and change the pem format to crt
frpcSSLFFlies = [os.path.join(_input_path,'configs/koishi_ssl')]
if 'frp_ssl_dir' in locals() or 'frp_ssl_dir' in globals():
    frpcSSLFFlies = frpcSSLFFlies + config_reader(locals().get('frp_ssl_dir') or globals().get('frp_ssl_dir'))
# Location of the frpc executable; when it is missing it is downloaded automatically (it can also be added from the viyiviyi/utils dataset)
frpcExePath = os.path.join(_input_path,'utils-tools/frpc')
# Extra webui start arguments; user-supplied ones come from the sd_start_args config
otherArgs = '--xformers'
if 'sd_start_args' in locals() or 'sd_start_args' in globals():
    # '--no-gradio-queue' is filtered out of the user-supplied arguments
    otherArgs = ' '.join([item for item in config_reader(locals().get('sd_start_args') or globals().get('sd_start_args')) if item != '--no-gradio-queue'])
venvPath = os.path.join(_input_path,'sd-webui-venv/venv.tar.bak') # pre-built python environment; sd-webui-venv is a public dataset that can be searched for and added

# Token file for the kaggle API, see https://www.kaggle.com/docs/api
# Used to upload koishi-related configuration automatically; can also be used to save important output files
kaggleApiTokenFile = os.path.join(_input_path,'configs/kaggle.json')

# Names of pip packages that still need to be installed; filled in by later cells
requirements = []


In [None]:
# Everything below initialises derived values / environment settings; do not change it lightly
_setting_file = replace_path(locals().get('setting_file') or globals().get('setting_file') or '/kaggle/working/configs/config.json')

_ui_config_file = replace_path(locals().get('ui_config_file') or globals().get('ui_config_file') or '/kaggle/working/configs/ui-config.json')

# Settings file location: when the Google Drive mount exists, untouched defaults move under _output_path
if Path(f"{os.environ['HOME']}/google_drive/MyDrive").exists():
    if _setting_file == '/kaggle/working/configs/config.json':
        _setting_file = os.path.join(_output_path,'configs/config.json')
    if _ui_config_file == '/kaggle/working/configs/ui-config.json':
        _ui_config_file = os.path.join(_output_path,'configs/ui-config.json')
    
# frpc: the user value is either a path to a config file or a ready-made argument string starting with '-f'
frpcStartArg = ''
_frp_config_or_file = replace_path(locals().get('frp_config_or_file') or globals().get('frp_config_or_file')) or frpcConfigFile
run(f'''mkdir -p {_install_path}/configFiles''')
if _frp_config_or_file:
    if Path(_frp_config_or_file.strip()).exists():
        frpcConfigFile = _frp_config_or_file.strip()
    if not Path(frpcConfigFile).exists(): 
        if _frp_config_or_file.strip().startswith('-f'):
            # no config file on disk: pass the value to frpc verbatim as its start argument
            frpcStartArg = _frp_config_or_file.strip()
        else:
            print('没有frpcp配置')
            _useFrpc = False
    else:
        # work on a private copy so local_port can be rewritten to the webui port
        run(f'''cp -f {frpcConfigFile} {_install_path}/configFiles/frpc_webui.ini''')
        frpcConfigFile = f'{_install_path}/configFiles/frpc_webui.ini'
        run(f'''sed -i "s/local_port = .*/local_port = {_server_port}/g" {frpcConfigFile}''')
        frpcStartArg = f' -c {frpcConfigFile}'

# ngrok: the user value is either a path to a token file or the token itself
ngrokToken=''
_ngrok_config_or_file = replace_path(locals().get('ngrok_config_or_file') or globals().get('ngrok_config_or_file')) or ngrokTokenFile
if _ngrok_config_or_file:
    if Path(_ngrok_config_or_file.strip()).exists():
        ngrokTokenFile = _ngrok_config_or_file.strip()
    if Path(ngrokTokenFile).exists():
        with open(ngrokTokenFile,encoding = "utf-8") as nkfile:
            # strip the line terminator: readline() keeps '\n', which is not part of the token
            ngrokToken = nkfile.readline().strip()
    elif not _ngrok_config_or_file.strip().startswith('/'):
        # not an absolute path, so treat the value as the token itself
        ngrokToken=_ngrok_config_or_file.strip()

# Fall back to the zip archive when the tar archive of the prepared venv is absent
if not Path(venvPath).exists():
    venvPath = os.path.join(_input_path,'sd-webui-venv/venv.zip')

# Authorization headers for huggingface downloads; set by init_huggingface()
huggingface_headers:dict = None  

## 文件下载工具

---

link_or_download_flie(config:str, skip_url:bool=False, _link_instead_of_copy:bool=True, base_path:str = '',sync:bool=False,thread_num:int=None)

In [None]:
import concurrent.futures
import importlib
import os
import pprint
import re
from pathlib import Path
from typing import List

import requests

# NOTE(review): this unconditionally resets the value derived from hidden_console_info in an
# earlier cell -- confirm that is intended.
show_shell_info = False

def is_installed(package):
    """Return True when ``package`` can be located by the import system.

    Args:
        package: Importable module name (dotted names are allowed).

    Returns:
        False when the module, or a parent package of a dotted name, is missing.
    """
    # `import importlib` alone does not guarantee that the `util` submodule is loaded.
    import importlib.util
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False

    return spec is not None

def download_file(url:str, filename:str, dist_path:str, cache_path = '',_link_instead_of_copy:bool=True,headers={}):
    """Download ``url`` into ``dist_path``, optionally through a cache directory.

    Args:
        url: Address of the file.
        filename: Target file name; when empty it is taken from the
            ``Content-Disposition`` header or from the last URL segment.
        dist_path: Directory the file should end up in.
        cache_path: Optional cache directory. When given, the file is stored
            there and then linked/copied into ``dist_path``.
        _link_instead_of_copy: Symlink from the cache instead of copying.
        headers: Extra HTTP headers (``None`` is accepted).

    Returns:
        None. Nothing is downloaded when the target already exists or when no
        file name can be determined.

    Raises:
        requests.HTTPError: When the server answers with an error status.
    """
    import shlex

    headers = headers or {}
    # Resolve the real file name
    if not filename:
        with requests.get(url, stream=True,headers=headers) as r:
            # A header without a plain `filename=` part must not raise.
            disposition = r.headers.get('Content-Disposition') or ''
            matched = re.search(r'filename="?([^";]+)', disposition)
            if matched:
                filename = matched.group(1).strip()
    if not filename and re.search(r'/[^/]+\.[^/]+$',url):
        filename = url.split('/')[-1].split('?')[0]

    filename = re.sub(r'[\\/:*?"<>|;]', '', filename or '')
    filename = re.sub(r'[\s\t]+', '_', filename)
    if not filename:
        # With an empty name the target path would be the directory itself and the
        # download would be skipped as "already exists" without any hint.
        print(f'无法确定文件名，已跳过下载 url: {url}')
        return

    print(f'下载 {filename} url: {url} --> {dist_path}')

    # Create directories
    if cache_path and not Path(cache_path).exists():
        os.makedirs(cache_path,exist_ok=True)
    if dist_path and not Path(dist_path).exists():
        os.makedirs(dist_path,exist_ok=True)

    # Full path of the final file
    filepath = os.path.join(dist_path, filename)

    if cache_path:
        cache_path = os.path.join(cache_path, filename)

    # Skip when the file is already in place
    if Path(filepath).exists():
        print(f'文件 {filename} 已存在 {dist_path}')
        return

    def place_from_cache():
        # Quote both paths: file names may contain shell metacharacters such as ( ) & $.
        run(f'cp -n -r -f {"-s" if _link_instead_of_copy else ""} {shlex.quote(cache_path)} {shlex.quote(dist_path)}')

    if cache_path and Path(cache_path).exists():
        place_from_cache()
        print(f'文件缓存 {cache_path} --> {dist_path}')
        return

    # Download into a temporary ".part" file so an interrupted transfer is never
    # mistaken for a complete file on the next run.
    target = cache_path or filepath
    partial = target + '.part'
    try:
        with requests.get(url, stream=True, headers=headers) as r:
            r.raise_for_status()
            with open(partial, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        os.replace(partial, target)
    except BaseException:
        if os.path.exists(partial):
            os.remove(partial)
        raise
    # When a cache directory is used the file still has to be copied or linked into place
    if cache_path:
        place_from_cache()
    print(f'下载完成 {filename} --> {dist_path}')
        
def download_git(url, dist_path, cache_path = '',_link_instead_of_copy:bool=True):
    """Clone a git repository for ``dist_path``, optionally through a cache directory.

    Without ``cache_path`` the clone runs inside ``dist_path``. With it, the
    clone runs inside the cache directory (only when that directory does not
    exist yet) and the cache content is then copied into ``dist_path``.
    ``_link_instead_of_copy`` is accepted for signature parity and not used.
    """
    clone_command = f'git clone {url}'

    if not Path(dist_path).exists():
        os.makedirs(dist_path,exist_ok=True)
    if show_shell_info:
        print(f'git 下载 {url} --> {dist_path}')

    if not cache_path:
        run(clone_command,cwd = dist_path)
    else:
        if not Path(cache_path).exists():
            os.makedirs(cache_path,exist_ok=True)
            run(clone_command,cwd = cache_path)
        run(f'cp -n -r -f {cache_path}/* {dist_path}')

    print(f'git 下载完成 {url} --> {dist_path}')
    
    
def download_huggingface(url:str, filename:str, dist_path, cache_path = '',_link_instead_of_copy:bool=True):
    """Download a huggingface URL as a single file or as a git repository.

    URLs matching the file pattern go through ``download_file`` with the
    huggingface authorization headers; every other URL is cloned with
    ``download_git``.
    """
    fileReg = r'^https:\/\/huggingface.co(\/([^\/]+\/)?[^\/]+\/[^\/]+\/(resolve|blob)\/[^\/]+\/|[^\.]+\.[^\.]+$|download=true)'
    if re.match(fileReg,url):
        download_file(url,filename,dist_path,cache_path,_link_instead_of_copy,headers=huggingface_headers)
        return
    download_git(url,dist_path,cache_path,_link_instead_of_copy)
    
# Turn one config entry into a download descriptor
def pause_url(url:str,dist_path:str):
    """Parse ``[file_name:]url`` into a download descriptor.

    Args:
        url: Either a plain http(s)/ftp(s) URL or ``name:URL``, where ``name``
            is the file name to save as.
        dist_path: Directory the download should be placed in.

    Returns:
        ``{'file_name', 'path_hash', 'url', 'dist_path'}``, or ``None`` when
        the entry is not a supported URL. ``path_hash`` is a stable digest of
        the URL and serves as the cache directory name.
    """
    import hashlib

    file_name = ''
    if re.match(r'^[^:]+:(https?|ftps?)://', url, flags=0):
        file_name = url.split(':', 1)[0]
        url = url[len(file_name)+1:]
    if not re.match(r'^(https?|ftps?)://',url):
        return
    file_name = re.sub(r'\s+','_',file_name or '')
    # The built-in hash() of a str is salted per process, so it would give a different
    # cache directory after every kernel restart; a digest stays stable.
    path_hash = hashlib.md5(url.encode('utf-8')).hexdigest()

    return {'file_name':file_name,'path_hash':path_hash,'url':url,'dist_path':dist_path}

def _plan_download(conf:dict, cache_path:str, _link_instead_of_copy:bool):
    """Return ``(downloader, args, kwargs)`` for one download descriptor."""
    url = conf['url']
    options = {
        'cache_path': os.path.join(cache_path, conf['path_hash']),
        '_link_instead_of_copy': _link_instead_of_copy,
    }
    if url.startswith('https://github.com'):
        return download_git, (url, conf['dist_path']), options
    if url.startswith('https://huggingface.co'):
        return download_huggingface, (url, conf['file_name'], conf['dist_path']), options
    if url.startswith('https://civitai.com') and not re.search(r'token=.+', url):
        # NOTE(review): a shared token is embedded in the source; prefer a personal one
        # supplied through the CIVITAI_TOKEN environment variable.
        token = os.environ.get('CIVITAI_TOKEN') or 'fee8bb78b75566eddfd04d061996185c'
        separator = '?' if url.find('?') == -1 else '&'
        url = conf['url'] = url + separator + 'token=' + token
    return download_file, (url, conf['file_name'], conf['dist_path']), options


def _report_download_error(future):
    """Print the exception of a finished background download instead of dropping it."""
    if future.cancelled():
        return
    error = future.exception()
    if error is not None:
        print(f'下载失败: {error}')


def download_urls(download_list:List[dict],sync:bool=False,thread_num:int=5, 
                  cache_path:str=os.path.join(os.environ['HOME'],'.cache','download_util'),
                  _link_instead_of_copy:bool=True,is_await:bool=False):
    """Download every descriptor in ``download_list``.

    Args:
        download_list: Descriptors as produced by ``pause_url``; ``None``
            entries (unparsable config lines) are skipped.
        sync: Download one after another in the calling thread.
        thread_num: Worker count for the background pool (non-sync mode).
        cache_path: Root of the download cache; each URL gets a sub directory.
        _link_instead_of_copy: Symlink from the cache instead of copying.
        is_await: In non-sync mode, block until all downloads have finished.
    """
    pending = [conf for conf in download_list if conf]

    if sync:
        for conf in pending:
            downloader, args, kwargs = _plan_download(conf, cache_path, _link_instead_of_copy)
            downloader(*args, **kwargs)
        return

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=thread_num)
    futures = []
    for conf in pending:
        downloader, args, kwargs = _plan_download(conf, cache_path, _link_instead_of_copy)
        future = executor.submit(downloader, *args, **kwargs)
        future.add_done_callback(_report_download_error)
        futures.append(future)
    # Already submitted work still runs; this only lets the pool release its threads when done.
    executor.shutdown(wait=False)
    if is_await:
        concurrent.futures.wait(futures)
            
                          
def parse_config(config:str):
    """Parse an ini-like list into ``{section_name: [entries]}``.

    Lines of the form ``[name]`` start a section; every other non-empty line
    (after dropping ``#`` comments and surrounding quotes/commas/whitespace)
    is an entry of the current section. Entries that appear before the first
    section header are collected under ``'default'``. A section that appears
    more than once accumulates its entries.

    Returns:
        A dict in order of first appearance. A config without any usable line
        yields ``{'default': []}``.
    """
    space_string = ' \n\r\t\'\",'
    entries = [item.split('#')[0].strip(space_string) for item in config.split('\n') if item.strip(space_string)]
    entries = [item.strip() for item in entries if item.strip()]

    sections = {}
    current = 'default'
    for item in entries:
        if item.startswith('[') and item.endswith(']'):
            current = item[1:-1]
            sections.setdefault(current, [])
        else:
            # Entries always go to the section they were written under, never to the next one.
            sections.setdefault(current, []).append(item)

    if not sections:
        sections['default'] = []
    return sections


def link_or_download_flie(config:str, skip_url:bool=False, _link_instead_of_copy:bool=True, base_path:str = '',
                          sync:bool=False,thread_num:int=None, is_await:bool=False):
    """Link/copy local files and download URLs listed in ``config``.

    Args:
        config: Ini-like text; each ``[section]`` names a directory relative to
            ``base_path`` and lists local paths and/or URLs to place there.
        skip_url: Ignore URL entries and only handle local paths.
        _link_instead_of_copy: Symlink local/cached files instead of copying.
        base_path: Root directory the section names are joined onto.
        sync: Download each URL immediately in the calling thread.
        thread_num: Worker count for background downloads (defaults to 3).
        is_await: Block until the background downloads have finished.
    """
    store:dict[str,List[str]] = parse_config(config)
    download_list = []
    for dist_dir in store.keys():
        dist_path = os.path.join(base_path,dist_dir)
        os.makedirs(dist_path,exist_ok=True)
        for path in store[dist_dir]:
            if 'https://' in path or 'http://' in path:
                if skip_url:
                    continue
                descriptor = pause_url(path,dist_path)
                if descriptor is None:
                    # pause_url rejects entries it cannot parse; skip them instead of
                    # passing None on to download_urls.
                    print(f'无法解析的下载地址，已跳过: {path}')
                    continue
                if sync:
                    download_urls([descriptor],_link_instead_of_copy = _link_instead_of_copy, sync=sync)
                    continue
                download_list.append(descriptor)
            else:
                run(f'cp -n -r -f {"-s" if _link_instead_of_copy else ""} {path} {dist_path}')
                if show_shell_info:
                    print(f'{"链接" if _link_instead_of_copy else "复制"} {path} --> {dist_path}')
        # Removes entries literally named "*.<something>". The backslash is doubled so the
        # shell still receives `\*` without Python's invalid-escape warning.
        run(f'rm -f {dist_path}/\\*.* ')
    if not skip_url:
        if show_shell_info:
            pprint.pprint(download_list)
        download_urls(download_list,_link_instead_of_copy = _link_instead_of_copy, sync=sync, thread_num=thread_num or 3,is_await=is_await)

## kaggle public API

**不能使用%cd这种会改变当前工作目录的命令，会导致和其他线程冲突**

---

In [None]:
# Install the kaggle API token file
def initKaggleConfig():
    """Make sure ``~/.kaggle/kaggle.json`` exists.

    Returns:
        True when the token file is already installed or could be copied from
        ``kaggleApiTokenFile``; False (after printing a hint) otherwise.
    """
    # pathlib does not expand '~' by itself; without expanduser() this check never succeeds.
    if Path('~/.kaggle/kaggle.json').expanduser().exists():
        return True
    if Path(kaggleApiTokenFile).exists():
        run(f'''mkdir -p ~/.kaggle/''')
        run('cp '+kaggleApiTokenFile+' ~/.kaggle/kaggle.json')
        run(f'''chmod 600 ~/.kaggle/kaggle.json''')
        return True
    print('缺少kaggle的apiToken文件，访问：https://www.kaggle.com/你的kaggle用户名/account 获取')
    return False

def getUserName():
    """Return the ``username`` entry of the kaggle API config, or None when no token is available."""
    if initKaggleConfig():
        import kaggle
        config = kaggle.KaggleApi().read_config_file()
        return config['username']
    return None

def createOrUpdateDataSet(path:str,datasetName:str):
    """Upload ``path`` as the kaggle dataset ``<username>/<datasetName>``.

    The file is staged in ``kaggle_cache/<datasetName>`` under the install
    directory. A dataset that is not in the user's own list is created;
    otherwise a new version is uploaded. Does nothing without a kaggle token.
    """
    if not initKaggleConfig():
        return
    print('创建或更新数据集 '+datasetName)
    import kaggle

    # Start from an empty staging area.
    cache_root = f'{_install_path}/kaggle_cache'
    run(f'mkdir -p {cache_root}')
    run(f'rm -rf {cache_root}/*')
    datasetDirPath = _install_path+'/kaggle_cache/'+datasetName
    run('mkdir -p '+datasetDirPath)
    run('cp -f '+path+' '+datasetDirPath+'/')

    username = getUserName()
    print("kaggle username:"+username)
    datasetPath = username+'/'+datasetName
    datasetList = kaggle.api.dataset_list(mine=True,search=datasetPath)
    print(datasetList)

    known_datasets = [str(dataset) for dataset in datasetList]
    if datasetPath in known_datasets:
        # update: fetch the current metadata, then push a new version
        kaggle.api.dataset_metadata(datasetPath,datasetDirPath)
        kaggle.api.dataset_create_version(datasetDirPath, 'auto update',dir_mode='zip')
        print('upload database done')
        return

    # create: generate the metadata template and fill in title and slug
    run('kaggle datasets init -p' + datasetDirPath)
    metadataFile = datasetDirPath+'/dataset-metadata.json'
    for marker in ('INSERT_TITLE_HERE', 'INSERT_SLUG_HERE'):
        run('sed -i s/' + marker + '/'+ datasetName + '/g ' + metadataFile)
    run('cat '+metadataFile)
    run('kaggle datasets create -p '+datasetDirPath)
    print('create database done')

def downloadDatasetFiles(datasetName:str,outputPath:str):
    """Download and unzip the user's kaggle dataset ``datasetName`` into ``outputPath``.

    Returns:
        True after a download, False when the dataset is not in the user's
        own list, None when no kaggle token is available.
    """
    if not initKaggleConfig():
        return
    print('下载数据集文件 '+datasetName)
    import kaggle

    owner = getUserName()
    datasetPath = owner+'/'+datasetName
    owned_datasets = kaggle.api.dataset_list(mine=True,search=datasetPath)
    is_owned = any(str(dataset) == datasetPath for dataset in owned_datasets)
    if not is_owned:
        return False

    run('mkdir -p '+outputPath)
    kaggle.api.dataset_download_files(datasetPath,path=outputPath,unzip=True)
    return True



## 同步文件夹到 huggingface

---

In [None]:
# Folder <-> huggingface synchronisation: collect the packages it needs
if _huggingface_token and _huggingface_repo:
    if not is_installed('watchdog'):
        requirements.append('watchdog')
    if not is_installed('huggingface_hub'):
        requirements.append('huggingface_hub')
    else:
        try:
            from huggingface_hub  import HfApi,login,snapshot_download
        except Exception:
            # Installed but not importable with the names needed here: schedule a reinstall.
            # Narrowed from a bare except so KeyboardInterrupt / SystemExit still propagate.
            requirements.append('huggingface_hub')

# Set to True by init_huggingface() once login succeeded and a repo is configured
huggingface_is_init = False

def init_huggingface():
    """Log in to huggingface and prepare the download headers.

    ``_huggingface_token`` may hold the token itself or the path of a text
    file whose first line is the token.

    Returns:
        True when login succeeded and ``_huggingface_repo`` is configured.
        False when the token is missing/invalid, or when only the token is
        available (downloads then work, but the sync feature stays disabled).
    """
    if not _huggingface_token:
        return False

    global huggingface_headers
    global huggingface_is_init

    from huggingface_hub  import login
    token = replace_path(_huggingface_token)
    if not _huggingface_token.startswith('hf_') and Path(token).exists():
        with open(token,encoding = "utf-8") as nkfile:
            token = nkfile.readline()
    # readline() keeps the line terminator; it must not end up in the Authorization header.
    token = token.strip()
    if not token.startswith('hf_'):
        print('huggingface token 不正确，请将 token 或 仅存放token 的txt文件路径填入 _huggingface_token 配置')
        return False
    login(token,add_to_git_credential=True)
    huggingface_headers = {'Authorization': 'Bearer '+token}
    print('huggingface token 已经加载，可以下载私有仓库或文件')

    if not _huggingface_repo:
        print('huggingface 同步收藏图片功能不会启动，可增加配置项 huggingface_token = "token" 和 huggingface_repo = “仓库id” 后启用 huggingface 同步收藏图片功能')
        return False
    huggingface_is_init = True
    return True


def download__huggingface_repo(repo_id:str,dist_directory:str=None,repo_type='dataset',callback=None):
    """Clone a huggingface repo into the local cache and symlink its content into ``dist_directory``.

    Args:
        repo_id: Repository id on huggingface.
        dist_directory: Target directory; defaults to the ``log`` folder of the web UI.
        repo_type: ``'dataset'`` or ``'space'`` select the matching URL prefix;
            any other value clones from the root namespace.
        callback: Optional callable invoked after linking.
    """
    if not huggingface_is_init:
        print('huggingface 相关功能未初始化 请调用 init_huggingface() 初始化')

    if not dist_directory:
        dist_directory = f'{_install_path}/{_ui_dir_name}/log'

    from huggingface_hub  import HfApi,login,snapshot_download

    api = HfApi()

    print('下载收藏的图片')
    cache_root = f'{_install_path}/cache/huggingface'
    if not Path(f'{cache_root}/huggingface_repo').exists():
        mkdirs(cache_root)
        # URL namespace for the repo type; an empty string means the root namespace
        repo_path = {'dataset': 'datasets', 'space': 'spaces'}.get(repo_type, '')
        if repo_path:
            clone_url = f'https://huggingface.co/{repo_path}/{repo_id}'
        else:
            clone_url = f'https://huggingface.co/{repo_id}'
        run(f'git clone {clone_url} huggingface_repo',cwd=cache_root)

    run(f'cp -r -f -n -s {cache_root}/huggingface_repo/* {dist_directory}')
    if callback:
        callback()

def start_sync_log_to_huggingface(repo_id:str,directory_to_watch:str=None,repo_type='dataset'):
    """Watch a directory and mirror image/text file changes to a huggingface repo.

    Args:
        repo_id: Target repository id.
        directory_to_watch: Directory to observe recursively; defaults to the
            ``log`` folder of the web UI.
        repo_type: Repository type passed to the huggingface API.

    Only files ending in .png/.jpg/.txt/.webp/.jpeg are synchronised. Created,
    modified and moved-in files are uploaded; deleted files are removed from
    the repo.
    """
    if not huggingface_is_init:
        print('huggingface 相关功能未初始化 请调用 init_huggingface() 初始化')

    from watchdog.observers import Observer
    from watchdog.events import FileSystemEventHandler
    from huggingface_hub  import HfApi

    synced_suffixes = ('.png','.jpg','.txt','.webp','.jpeg')

    # Event handler bound to the watched directory and the target repository
    class FileChangeHandler(FileSystemEventHandler):
        def __init__(self, api, repo_id, repo_type,directory_to_watch):
            self.api = api
            self.repo_id = repo_id
            self.repo_type = repo_type
            self.directory_to_watch = directory_to_watch

        def _is_synced_file(self, file_path):
            """True when the file name ends with one of the synchronised suffixes."""
            file_name = os.path.basename(file_path)
            # rfind instead of rindex: a name without any '.' must be ignored, not raise.
            dot = file_name.rfind('.')
            return dot >= 0 and file_name[dot:] in synced_suffixes

        def _path_in_repo(self, file_path):
            """Repo path of a local file: the watched directory prefix is removed."""
            return file_path.replace(self.directory_to_watch,'')

        def _upload(self, file_path):
            """Upload one file, printing I/O errors instead of raising."""
            try:
                self.api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=self._path_in_repo(file_path),
                    repo_id=self.repo_id,
                    repo_type=self.repo_type,
                )
            except IOError as error:
                print(error)

        def on_created(self, event):
            if event.is_directory:
                return
            # upload the new file to the huggingface repo
            file_name = os.path.basename(event.src_path)
            print(file_name)
            if not self._is_synced_file(event.src_path): return
            print(file_name,'>>','huggingface')
            self._upload(event.src_path)

        def on_deleted(self, event):
            if event.is_directory or not self._is_synced_file(event.src_path):
                return
            # remove the file from the huggingface repo
            try:
                self.api.delete_file(
                    path_in_repo=self._path_in_repo(event.src_path),
                    repo_id=self.repo_id,
                    repo_type=self.repo_type,
                    )
            except IOError as error:
                print(error)

        def on_modified(self, event):
            if event.is_directory or not self._is_synced_file(event.src_path):
                return
            # refresh the file in the huggingface repo
            self._upload(event.src_path)

        def on_moved(self, event):
            if event.is_directory or not self._is_synced_file(event.dest_path):
                return
            # only files that end up inside the watched directory are uploaded
            if event.dest_path.startswith(self.directory_to_watch):
                self._upload(event.dest_path)

    api = HfApi()

    if not directory_to_watch:
        directory_to_watch = f'{_install_path}/{_ui_dir_name}/log'
    # Create the observer and register the change handler
    event_handler = FileChangeHandler(api,repo_id,repo_type,directory_to_watch)
    observer = Observer()
    observer.schedule(event_handler, directory_to_watch, recursive=True)

    # Start the observer; the "solo_" name prefix marks it for the thread cleanup tool
    observer.name = "solo_directory_to_watch"
    print(f'启动收藏图片文件夹监听，并自动同步到 huggingface {repo_type} : {repo_id}')
    observer.start()

## 工具函数
**不能使用%cd这种会改变当前工作目录的命令，会导致和其他线程冲突**

---

In [None]:
def echoToFile(content:str,path:str):
    """Write ``content`` to ``path`` (UTF-8), creating parent directories as needed.

    Args:
        content: Text to write; an existing file is overwritten.
        path: Target file path.
    """
    # Create the directory in-process: a shell `mkdir -p` on the raw string breaks on
    # spaces or shell metacharacters in the path.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path,'w',encoding='utf-8') as sh:
        sh.write(content)

def zipPath(path:str,zipName:str,format='tar'):
    """Archive a directory into ``<output dir>/<zipName>.tar`` (or ``.tar.gz``).

    Args:
        path: Directory to archive. May start with ``$install_path`` or
            ``$output_path``; any other non-absolute path is resolved relative
            to the web UI directory.
        zipName: Archive name without extension.
        format: ``'tar'`` for an uncompressed archive, ``'gz'`` for gzip.
            Other values create nothing.
    """
    if path.startswith('$install_path'):
        path = path.replace('$install_path',_install_path)
    if path.startswith('$output_path'):
        # this branch used to replace '$install_path', so '$output_path' was never expanded
        path = path.replace('$output_path',_output_path)
    if not path.startswith('/'):
        path = f'{_install_path}/{_ui_dir_name}/{path}'
    if Path(path).exists():
        if 'tar' == format:
            run(f'tar -cf {_output_path}/'+ zipName +'.tar -C '+ path +' . ')
        elif 'gz' == format:
            run(f'tar -czf {_output_path}/'+ zipName +'.tar.gz -C '+ path +' . ')
        return
    print('指定的目录不存在：'+path)

# Check whether a TCP service is reachable
def check_service(host, port):
    """Return True when a TCP connection to ``host:port`` succeeds within 5 seconds."""
    try:
        # the context manager closes the probe socket instead of leaking it
        with socket.create_connection((host, port), timeout=5):
            return True
    except socket.error:
        return False

## 内网穿透

---

In [None]:
def printUrl(url):
    """Print the public ``url`` and the public address of every proxied local service.

    Proxy paths are listed longest first, the order ``sorted(..., key=len)[::-1]`` yields.
    """
    print(f'访问地址：{url}')
    for key in sorted(_proxy_path.keys(), key=len)[::-1]:
        # a stray literal "f" used to be printed in front of the URL here
        print(f'本地服务：{_proxy_path[key]}  访问地址：{url}{key}')
# ngrok
def startNgrok(ngrokToken:str,ngrokLocalPort:int):
    if not is_installed('pyngrok'):
        %pip install pyngrok
    from pyngrok import conf, ngrok
    try:
        conf.get_default().auth_token = ngrokToken
        conf.get_default().monitor_thread = False
        ssh_tunnels = ngrok.get_tunnels(conf.get_default())
        url = ''
        if len(ssh_tunnels) == 0:
            ssh_tunnel = ngrok.connect(ngrokLocalPort)
            url = ssh_tunnel.public_url
            print('ngrok 访问地址：'+ssh_tunnel.public_url)
        else:
            print('ngrok 访问地址：'+ssh_tunnels[0].public_url)
            url = ssh_tunnels[0].public_url
        printUrl(url)
        def auto_request_ngrok():
            if url:
                while(_runing):
                    time.sleep(60*1)
                    res = requests.get(url+'/sdapi/v1/samplers',headers={"ngrok-skip-browser-warning" : "1"})
                    # print('自动调用ngrok链接以保存链接不会断开',res.status_code)

        threading.Thread(target = auto_request_ngrok,daemon=True,name='solo_auto_request_ngrok').start()
    except:
        print('启动ngrok出错')
        
def startFrpc(name,configFile):
    """Write a start script for frpc and run it through the IPython shell.

    Args:
        name: Accepted for the caller's convenience; not used.
        configFile: Argument string appended to the frpc command line.

    frpc is installed first when its executable is missing.
    """
    frpc_dir = f'{_install_path}/frpc'
    if not Path(f'{frpc_dir}/frpc').exists():
        installFrpExe()

    script_path = f'{frpc_dir}/start.sh'
    script_body = f'''
cd {frpc_dir}/
{frpc_dir}/frpc {configFile}
'''
    echoToFile(script_body, script_path)
    get_ipython().system(f'''bash {script_path}''')
        
def installFrpExe():
    """Install the frpc executable and its SSL files into ``<install dir>/frpc``.

    Does nothing when frpc is disabled. The executable is copied from
    ``frpcExePath`` when that exists, otherwise downloaded.
    """
    if not _useFrpc:
        return

    print('安装frpc')
    frpc_dir = f'{_install_path}/frpc'
    frpc_bin = f'{frpc_dir}/frpc'
    run(f'mkdir -p {frpc_dir}')

    if Path(frpcExePath).exists():
        run(f'cp -f -n {frpcExePath} {frpc_bin}')
    else:
        run(f'wget "https://huggingface.co/datasets/ACCA225/Frp/resolve/main/frpc" -O {frpc_bin}')

    # copy the certificate files of every SSL directory that exists
    for ssl_dir in (candidate for candidate in frpcSSLFFlies if Path(candidate).exists()):
        run(f'cp -f -n {ssl_dir}/* {frpc_dir}/')

    run(f'chmod +x {frpc_bin}')
    run(f'{frpc_bin} -v')

def startProxy():
    """Start the enabled tunnels: ngrok first (if ``_useNgrok``), then frpc (if ``_useFrpc``)."""
    if _useNgrok:
        startNgrok(ngrokToken=ngrokToken, ngrokLocalPort=_server_port)
    if _useFrpc:
        startFrpc(name='frpc_proxy', configFile=frpcStartArg)

## NGINX 反向代理

---

In [None]:

# nginx 反向代理配置文件
def localProxy():
    """Generate the nginx reverse-proxy config and (re)load nginx.

    Registers the first two entries of ``_sub_path`` in ``_proxy_path`` so that
    they point at the local ports ``_server_port + 1`` and ``_server_port + 2``,
    renders one ``location`` block per ``_proxy_path`` entry (longest key
    first), writes the result to ``/etc/nginx/conf.d/proxy_nginx.conf``, starts
    nginx when nothing answers on ``_server_port`` yet and finally reloads it.
    """
    for idx in (0, 1):
        _proxy_path[_sub_path[idx]] = f'http://127.0.0.1:{_server_port+1+idx}/'

    def render_location(sub_path:str, upstream:str):
        """Return the nginx ``location`` block that proxies ``sub_path`` to ``upstream``."""
        block_lines = [
            '',
            '    location ' + sub_path,
            '    {',
            '        proxy_pass ' + upstream + ';',
            '        proxy_set_header Host $host;',
            '        proxy_set_header X-Real-IP $remote_addr;',
            '        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;',
            '        proxy_set_header REMOTE-HOST $remote_addr;',
            '        proxy_set_header   Upgrade $http_upgrade;',
            '        proxy_set_header   Connection upgrade;',
            '        proxy_http_version 1.1;',
            '        proxy_connect_timeout 10m;',
            '        proxy_read_timeout 10m;',
            '    }',
            '    ',
            '    ',
        ]
        return '\n'.join(block_lines)

    # Longest path first; ties keep the reversed order of a stable ascending sort.
    ordered_keys = reversed(sorted(_proxy_path.keys(), key=len))
    locations = ''.join(render_location(key, _proxy_path[key]) for key in ordered_keys)

    listen_port = str(_server_port)
    server_lines = [
        '',
        'server',
        '{',
        '    listen ' + listen_port + ';',
        '    listen [::]:' + listen_port + ';',
        '    server_name 127.0.0.1 localhost 0.0.0.0 "";',
        '    ',
        '    if ($request_method = OPTIONS) {',
        '        return 200;',
        '    }',
        '    fastcgi_send_timeout                 10m;',
        '    fastcgi_read_timeout                 10m;',
        '    fastcgi_connect_timeout              10m;',
        '    ',
        '    ' + locations,
        '}',
        '',
    ]
    nginx_conf = '\n'.join(server_lines)

    echoToFile(nginx_conf, '/etc/nginx/conf.d/proxy_nginx.conf')
    already_listening = check_service('localhost', _server_port)
    if not already_listening:
        run('nginx -c /etc/nginx/nginx.conf')
    run('nginx -s reload')

## 线程清理工具

---

清理线程名以 solo_ 开头的所有线程

In [None]:
import inspect
import ctypes

def _async_raise(tid, exctype):
    """Ask CPython to raise ``exctype`` asynchronously inside thread ``tid``.

    Args:
        tid: identifier of the target thread (``Thread.ident``).
        exctype: exception class to raise; an instance is replaced by its class.

    Raises:
        ValueError: no thread state matched ``tid``.
        SystemError: more than one thread state was affected; the pending
            exception is cleared again before raising.
    """
    native_tid = ctypes.c_long(tid)
    exc_class = exctype if inspect.isclass(exctype) else type(exctype)
    set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc

    affected = set_async_exc(native_tid, ctypes.py_object(exc_class))
    if affected == 1:
        return
    if affected == 0:
        raise ValueError("invalid thread id")
    # Any other count means trouble: call again with NULL to revert the effect.
    set_async_exc(native_tid, None)
    raise SystemError("PyThreadState_SetAsyncExc failed")

def stop_thread(thread):
    """Request termination of ``thread`` by raising ``SystemExit`` inside it."""
    target_id = thread.ident
    _async_raise(target_id, SystemExit)

def stop_solo_threads():
    """Signal background work to stop and terminate every ``solo_*`` thread.

    Clears the module-level ``_runing`` flag and then asks ``stop_thread`` to
    end each live thread whose name starts with ``solo_``. A failure for one
    thread is reported and does not prevent the remaining threads from being
    stopped.
    """
    global _runing
    _runing = False
    # Snapshot of the threads that are alive right now.
    for thread in threading.enumerate():
        if not thread.name.startswith('solo_'):
            continue
        print(f'结束线程：{thread.name}')
        try:
            stop_thread(thread)
        except (ValueError, SystemError, OSError):
            # _async_raise signals failure with ValueError (thread already
            # gone) or SystemError; OSError (socket.error) is kept for
            # compatibility with the previous handler.
            print(f'结束线程：{thread.name} 执行失败')

# webui 安装和配置函数
---

In [None]:
# Set to True once the webui virtualenv is usable (or installation is skipped); main() waits on this flag.
envInstalled=False
# When True, install_dependencies() unpacks a prebuilt venv archive before creating the venv.
quickStart = True
#安装
def install():
    """Clone or refresh the webui checkout and install the extra pip packages.

    Runs inside ``_install_path``: sets up git-lfs and the git credential
    store, pip-installs every entry of ``requirements``, deletes the checkout
    when ``_reLoad`` is set, then either updates the existing checkout
    (``git checkout .`` + ``git pull``) or clones ``_sd_git_repo`` into
    ``_ui_dir_name``. The working directory is left inside the checkout.
    """
    print('安装')
    ui_root = f'{_install_path}/{_ui_dir_name}'
    os.chdir(f'{_install_path}')

    for git_setup in ('git lfs install', 'git config --global credential.helper store'):
        run(git_setup)
    for requirement in requirements:
        run(f'pip install {requirement}')

    if _reLoad:
        run(f'rm -rf {ui_root}')

    # Relative check: the current directory is _install_path at this point.
    checkout_present = Path(f"{_ui_dir_name}").exists()
    if not checkout_present:
        run(f'git clone --recursive {_sd_git_repo} {_ui_dir_name}')
    else:
        os.chdir(f'{ui_root}/')
        run('git checkout .')
        run('git pull')

    os.chdir(ui_root)
    print('安装 完成')

# 链接输出目录
def link_dir():
    """Replace the webui's output folders with symlinks into ``_output_path``.

    For ``outputs``, ``log`` and ``textual_inversion`` the folder inside the
    webui checkout is removed, the matching folder under ``_output_path`` is
    created, and a relative symlink to it is placed in the checkout.
    """
    print('链接输出目录')
    out_root = _output_path
    ui_root = f'{_install_path}/{_ui_dir_name}'

    steps = (
        # generated images
        f'mkdir -p {out_root}/outputs',
        f'rm -rf {ui_root}/outputs',
        f'ln -s -r {out_root}/outputs {ui_root}/',
        # "log" folder (described as the favourites folder in the original notes)
        f'mkdir -p {out_root}/log',
        f'rm -rf {ui_root}/log',
        f'ln -s -r {out_root}/log {ui_root}/',
        # training output; the original note warns that linking this folder
        # may break the feature -- TODO confirm
        f'rm -rf {ui_root}/textual_inversion',
        f'mkdir -p {out_root}/textual_inversion/',
        f'ln -s -r {out_root}/textual_inversion {ui_root}/',
    )
    for step in steps:
        run(step)

    print('链接输出目录 完成')

def install_optimizing():
    """Refresh the apt package index and install nginx."""
    for apt_cmd in ('sudo apt update -y', 'sudo apt install nginx -y'):
        run(apt_cmd)
    
#安装依赖
def install_dependencies():
    """Prepare the Python virtualenv used to run the webui.

    If ``{_install_path}/{_ui_dir_name}/venv`` already exists nothing is
    installed. Otherwise, when ``quickStart`` is set, a prebuilt environment
    archive is unpacked into that folder: it is taken from ``venvPath`` (a
    ``.zip`` is unzipped into the cache folder first) or, when ``venvPath`` is
    unset or missing, downloaded into ``{_install_path}/venv_cache``.
    Afterwards the interpreter and pip launchers are recreated with
    ``venv.create`` and pip is bootstrapped with get-pip.py if it is still
    missing.

    Sets the module-level ``envInstalled`` flag to True when done and may
    rebind the module-level ``venvPath`` to the cached archive.
    """
    print('安装需要的python环境')
    import venv
    global envInstalled
    global venvPath
    ui_root = f'{_install_path}/{_ui_dir_name}'
    if Path(f'{ui_root}/venv').exists():
        print('跳过安装python环境')
        envInstalled = True
        return

    if quickStart:
        cache_dir = f'{_install_path}/venv_cache'
        cached_archive = f'{cache_dir}/venv.tar.bak'
        if not venvPath or not Path(venvPath).exists():
            mkdirs(cache_dir,True)
            if not Path(cached_archive).exists():
                print('下载 venv.zip')
                download_file('https://huggingface.co/viyi/sdwui/resolve/main/venv.zip','venv.zip',cache_dir)
                # Only unpack and delete the zip when it was actually
                # downloaded; with a cached archive there is no zip to unpack.
                run(f'''unzip {cache_dir}/venv.zip -d {cache_dir}''')
                run(f'''rm -rf {cache_dir}/venv.zip''')
            venvPath = cached_archive
        elif venvPath.endswith('.zip'):
            mkdirs(cache_dir,True)
            run(f'''unzip {venvPath} -d {cache_dir}''')
            venvPath = cached_archive
        print('解压环境')
        mkdirs(f'{ui_root}/venv')
        run(f'tar -xf {venvPath} -C ./venv',cwd=ui_root)

    # Drop the pip/python launchers and let venv.create() regenerate them
    # (presumably because the archived ones point at another interpreter --
    # TODO confirm).
    run(f'rm -f {ui_root}/venv/bin/pip*')
    run(f'rm -f {ui_root}/venv/bin/python*')
    venv.create(f'{ui_root}/venv')
    if not Path(f'{ui_root}/venv/bin/pip').exists():
        run('curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py')
        run(f'{ui_root}/venv/bin/python3 get-pip.py')

    get_ipython().system(f'''{ui_root}/venv/bin/python3 -V''')
    get_ipython().system(f'''{ui_root}/venv/bin/python3 -m pip -V''')

    envInstalled = True
    print('安装需要的python环境 完成')
      
# 个性化配置 
def use_config():
    """Install the shared default configuration into the webui checkout.

    Clones ``_sd_config_git_repu`` into ``{_install_path}/temp/sd-configs`` and
    copies its ``dist`` folder over the checkout without overwriting existing
    files. When ``_ui_config_file`` / ``_setting_file`` are configured but do
    not exist yet, they are seeded from the checkout's ``ui-config.json`` /
    ``config.json``.
    """
    print('使用自定义配置 包括tag翻译 \n')
    run(f'''mkdir -p {_install_path}/temp''')
    run(f'git clone {_sd_config_git_repu} sd-configs',cwd=f'{_install_path}/temp')
    run(f'cp -r -f -n {_install_path}/temp/sd-configs/dist/* {_install_path}/{_ui_dir_name}')

    seeds = (
        (_ui_config_file, 'ui-config.json'),  # UI config file
        (_setting_file, 'config.json'),       # settings file
    )
    for target, default_name in seeds:
        # None / '' means "not configured" (start() applies the same rule).
        if not target or Path(target).exists():
            continue
        # Path.parent also copes with a bare file name, unlike slicing at rfind('/').
        run(f'mkdir -p {Path(target).parent}')
        run(f'cp -f -n {_install_path}/{_ui_dir_name}/{default_name} {target}')

def copy_last_log_to_images():
    """Copy the highest-numbered image from ``log/images`` into the output folders.

    File names in ``{_install_path}/{_ui_dir_name}/log/images`` that start with
    ``<digits>-`` are considered; the one with the largest number is copied to
    ``outputs/txt2img-images`` and ``outputs/img2img-images`` (both are created
    if needed). According to the message printed below, this keeps the
    numbering of new images from colliding with saved ones. When no numbered
    image exists, only the folders are created.
    """
    ui_root = f'{_install_path}/{_ui_dir_name}'
    log_images = f'{ui_root}/log/images'
    if not Path(log_images).exists(): mkdirs(log_images)
    print('复制编号最大的一张收藏图到输出目录，用于保持编号，否则会出现收藏的图片被覆盖的情况')

    last_img_path = ''
    last_img_num = -1  # below 0 so that an image numbered 0 can be selected
    for img in os.listdir(log_images):
        numbered = re.match(r'\d+-', str(img))
        if numbered is None:
            continue
        num = int(numbered.group()[:-1])
        if num > last_img_num:
            last_img_path = img
            last_img_num = num

    for sub_dir in ('txt2img-images', 'img2img-images'):
        target_dir = f'{ui_root}/outputs/{sub_dir}'
        run(f'''mkdir -p {target_dir}''')
        if not last_img_path:
            # Nothing to copy; without this guard cp would be pointed at the folder itself.
            continue
        print(f'{log_images}/{last_img_path} {target_dir}')
        run(f'''cp -f {log_images}/{last_img_path} {target_dir}/''')
    
def start_webui(i):
    """Run webui instance ``i`` and restart it whenever it exits.

    The instance is launched from the checkout with ``--device-id={i}``, port
    ``_server_port + 1 + i`` and sub-path ``_sub_path[i]``. The loop runs while
    ``_runing`` is true. A restart that happens less than 60 seconds after the
    previous one increases a counter (any slower restart resets it); once the
    counter exceeds 3 the instance is no longer restarted.

    Args:
        i: zero-based instance index, also used as the GPU device id.
    """
    # Original note: as long as memory is not exhausted the service is started
    # again after it stops; the access address will change.
    port = _server_port+1+i
    print(i,'--port',str(port))
    if i>0:
        print(f'使用第{i+1}张显卡启动第{i+1}个服务，通过frpc或nrgok地址后加/{i}/进行访问')

    restart_times = 0
    last_restart_time = time.time()
    while _runing:
        os.chdir(f'{_install_path}/{_ui_dir_name}')
        get_ipython().system(f'''venv/bin/python3 launch.py --device-id={i} --port {str(port)} --subpath={_sub_path[i]}''')
        if time.time() - last_restart_time < 60:
            restart_times = restart_times + 1
        else:
            restart_times = 0
        last_restart_time = time.time()
        if restart_times >3 :
            # Too many quick restarts in a row: say so instead of announcing
            # a restart that will never happen.
            print('服务连续多次启动失败，不再自动重启')
            break
        print('10秒后重启服务')
        time.sleep(10)
    
# 启动
def start():
    """Export the launch arguments and start the webui instance(s).

    Builds ``COMMANDLINE_ARGS`` from the existing UI-config / settings files
    plus ``otherArgs``, sets ``REQS_FILE``, then submits ``start_webui`` for
    each instance to a thread pool. With ``_multi_case`` one instance per CUDA
    device is started, otherwise a single one. The next instance is only
    submitted after the previous one answers on its port.
    """
    print('启动')
    os.chdir(f'''{_install_path}/{_ui_dir_name}''')
    args = ''
    if _ui_config_file is not None and _ui_config_file != '' and Path(_ui_config_file).exists(): # UI config file
        args += ' --ui-config-file=' + _ui_config_file
    if _setting_file is not None and _setting_file != '' and Path(_setting_file).exists(): # settings file
        args += ' --ui-settings-file=' + _setting_file
    args += ' ' + otherArgs
    os.environ['COMMANDLINE_ARGS']=args
    run(f'''echo COMMANDLINE_ARGS=$COMMANDLINE_ARGS''')
    os.environ['REQS_FILE']='requirements.txt'

    # The pool has two workers and localProxy() maps two backends, so more
    # instances could never all run; and at least one instance must start even
    # when no CUDA device is reported.
    max_instances = 2
    if _multi_case:
        instance_count = min(max(torch.cuda.device_count(), 1), max_instances)
    else:
        instance_count = 1

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_instances) as executor:
        for i in range(instance_count):
            executor.submit(start_webui,i)
            # Leave this iteration only once the current service is up.
            while _runing and not check_service('localhost',str(_server_port+1+i)):
                time.sleep(5)
            if not _runing: break
            time.sleep(10)

# 入口函数
---

In [None]:

# 启动非webui相关的的内容，加快启动速度
def main():
    """Entry point: install what is missing, start the proxies and launch the webui.

    Stops leftover ``solo_*`` threads, starts the tunnel thread, and - unless
    the ``IsInstall`` environment variable is ``'True'`` and ``_reLoad`` is
    false - installs the checkout, links the output folders, prepares the
    Python environment and downloads the configured files. It then writes the
    nginx config, runs the pre-start downloads, optionally starts the
    huggingface sync, calls ``_on_before_start()`` and finally ``start()``.
    """
    global envInstalled
    global huggingface_is_init
    global _runing
    stop_solo_threads()
    # presumably gives the stopped threads time to exit -- TODO confirm
    time.sleep(15)
    _runing = True
    startTicks = time.time()
    # The IsInstall environment variable marks that installation already ran in this process environment.
    isInstall = True if os.getenv('IsInstall','False') == 'True' else False
    threading.Thread(target = startProxy, daemon=True, name='solo_startProxy').start()
    if isInstall is False or _reLoad: 
        print('启动 安装运行环境')
        install()
        link_dir()
        init_huggingface()
        # The Python environment and nginx are installed in background threads.
        threading.Thread(target = install_dependencies,daemon=True,name='solo_install_dependencies').start()
        threading.Thread(target = install_optimizing,daemon=True,name='solo_install_optimizing').start()
        link_or_download_flie(replace_path(_async_downloading), _link_instead_of_copy=_link_instead_of_copy,
                          base_path=f'{_install_path}/{_ui_dir_name}')
        if huggingface_is_init:
            threading.Thread(target = download__huggingface_repo,daemon=True,
                                 args=([_huggingface_repo]),
                                 kwargs={"callback":copy_last_log_to_images},
                                 name='solo_download__huggingface_repo').start()
        
        link_or_download_flie(replace_path(_before_downloading), _link_instead_of_copy=_link_instead_of_copy,
                          base_path=f'{_install_path}/{_ui_dir_name}',is_await=True,sync=True)
        # Wait until install_dependencies() sets envInstalled; print a progress line every 10 seconds.
        t = 0
        while _runing and not envInstalled:
            if t%10==0:
                print('等待python环境安装...')
            t = t+1
            time.sleep(1)
        use_config()
        os.environ['IsInstall'] = 'True'
    else:
        envInstalled = True
    # NOTE(review): nothing here waits for the install_optimizing thread (which installs nginx)
    # before localProxy() writes the nginx config and starts nginx -- confirm this cannot race.
    localProxy()
    link_or_download_flie(replace_path(_before_start_sync_downloading), _link_instead_of_copy=_link_instead_of_copy,
                          base_path=f'{_install_path}/{_ui_dir_name}',sync=True)
    if init_huggingface():
        start_sync_log_to_huggingface(_huggingface_repo)
    ticks = time.time()
    _on_before_start()
    print("加载耗时:",(ticks - startTicks),"秒")
    start()


# 执行区域
---

In [None]:
# Launch
# Uncomment any of the following lines to override a setting before starting:
# _reLoad = True
# hidden_console_info = False
# run_by_none_device = True
# show_shell_info = True

print(f'当前sd的安装路径是：{_install_path}/{_ui_dir_name}')
print(f'当前图片保存路径是：{_output_path}')
print(f'当前数据集路径是：{_input_path}')

print(update_desc)

if not _skip_start:
    # Automatic start: check for a GPU, then run everything.
    try:
        check_gpu()
        main()
    except KeyboardInterrupt:
        # Interrupting the cell also stops the background solo_ threads.
        stop_solo_threads()
else:
    # Automatic start was skipped: show how to start manually.
    print('已跳过自动启动，可手动执行 main() 进行启动。')
    print('''推荐的启动代码：
try:
    check_gpu() # 检查是否存在gpu
    main()
except KeyboardInterrupt:
    stop_solo_threads() # 中断后自动停止后台线程 （有部分功能在后台线程中运行）
    ''')