Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import configparser | |
import dataclasses | |
import os | |
import threading | |
import re | |
import json | |
from modules import shared, errors, cache, scripts | |
from modules.gitpython_hack import Repo | |
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 | |
from modules_forge.config import always_disabled_extensions | |
# Every Extension object created by list_extensions(), in discovery order.
extensions: list[Extension] = []
# Maps an extension's directory path to its Extension object.
extension_paths: dict[str, Extension] = {}
# Maps the name an extension was registered under by list_extensions() to its Extension object.
loaded_extensions: dict[str, Extension] = {}

# Make sure the extensions directory exists.
os.makedirs(extensions_dir, exist_ok=True)
def active():
    """Return the extensions that should currently be used.

    Returns an empty list when all extensions are disabled (command-line flag or
    the "all" option), only enabled built-in extensions when extra extensions are
    disabled (command-line flag or the "extra" option), and every enabled
    extension otherwise.
    """
    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
        return []

    builtin_only = shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra"
    if builtin_only:
        return [ext for ext in extensions if ext.enabled and ext.is_builtin]

    return [ext for ext in extensions if ext.enabled]
@dataclasses.dataclass
class CallbackOrderInfo:
    """Ordering instruction for one callback: which names it runs before and after.

    Declared as a dataclass so it can be constructed positionally as
    ``CallbackOrderInfo(name, before, after)``.
    """

    name: str
    before: list
    after: list
class ExtensionMetadata:
    """Settings read from the ``metadata.ini`` file in an extension's directory.

    Attributes:
        config: parser holding whatever could be read from the ini file
            (empty when the file is missing or unreadable).
        canonical_name: lower-cased, stripped name of the extension; taken from
            ``Name`` in the ``[Extension]`` section when present, otherwise from
            the name passed to the constructor.
        requires: initialised to None by the constructor.
    """

    filename = "metadata.ini"
    config: configparser.ConfigParser
    canonical_name: str
    requires: list

    def __init__(self, path, canonical_name):
        """Read ``metadata.ini`` from directory `path`; `canonical_name` is the fallback name."""
        self.config = configparser.ConfigParser()

        filepath = os.path.join(path, self.filename)
        # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
        # so no need to check whether the file exists beforehand.
        try:
            self.config.read(filepath)
        except Exception:
            errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)

        # Normalise the name that was actually read from the ini file. Previously the
        # value read here was immediately overwritten by the fallback argument, so a
        # declared Name was never honoured.
        declared_name = self.config.get("Extension", "Name", fallback=canonical_name)
        self.canonical_name = declared_name.lower().strip()
        self.requires = None

    def get_script_requirements(self, field, section, extra_section=None):
        """reads a list of requirements from the config; field is the name of the field in the ini file,
        like Requires or Before, and section is the name of the [section] in the ini file; additionally,
        reads more requirements from [extra_section] if specified."""

        raw = self.config.get(section, field, fallback='')

        if extra_section:
            raw = raw + ', ' + self.config.get(extra_section, field, fallback='')

        res = []
        for requirement in self.parse_list(raw.lower()):
            # A requirement may list alternatives separated by "|": use the first
            # alternative that is a key of loaded_extensions, or keep the whole
            # requirement text unchanged when none of them is.
            loaded_alternatives = (alt for alt in requirement.split("|") if alt in loaded_extensions)
            res.append(next(loaded_alternatives, requirement))

        return res

    def parse_list(self, text):
        """converts a line from config ("ext1 ext2, ext3  ") into a python list (["ext1", "ext2", "ext3"])"""

        if not text:
            return []

        # both "," and " " are accepted as separator
        return [item for item in re.split(r"[,\s]+", text.strip()) if item]

    def list_callback_order_instructions(self):
        """Yield a CallbackOrderInfo for each ``[callbacks/...]`` section of the config.

        Sections whose name after the ``callbacks/`` prefix does not start with this
        extension's canonical name are reported as errors and skipped.
        """
        prefix = "callbacks/"

        for section in self.config.sections():
            if not section.startswith(prefix):
                continue

            callback_name = section[len(prefix):]

            if not callback_name.startswith(self.canonical_name):
                errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
                continue

            before = self.parse_list(self.config.get(section, 'Before', fallback=''))
            after = self.parse_list(self.config.get(section, 'After', fallback=''))

            yield CallbackOrderInfo(callback_name, before, after)
class Extension:
    """An extension directory together with its metadata and Git repository state."""

    # Class-level lock shared by all instances; held while reading repository info.
    lock = threading.Lock()
    # Attributes saved by to_dict() and restored by from_dict().
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
    metadata: ExtensionMetadata

    def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
        """Create an extension record.

        name: extension name (its lower-cased form is used when metadata must be created).
        path: directory of the extension.
        enabled: whether the extension is enabled.
        is_builtin: whether the extension is a built-in one.
        metadata: an ExtensionMetadata to use; when falsy, one is read from `path`.
        """
        self.name = name
        self.path = path
        self.enabled = enabled
        self.status = ''
        self.can_update = False
        self.is_builtin = is_builtin
        self.commit_hash = ''
        self.commit_date = None
        self.version = ''
        self.branch = None
        self.remote = None
        self.have_info_from_repo = False
        self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
        # Read from self.metadata, not the argument: the argument is None by default.
        self.canonical_name = self.metadata.canonical_name

        self.is_forge_space = False
        self.space_meta = None

        # A directory holding both space_meta.json and forge_app.py is flagged as a forge space.
        space_meta_path = os.path.join(self.path, 'space_meta.json')
        if os.path.exists(space_meta_path) and os.path.exists(os.path.join(self.path, 'forge_app.py')):
            self.is_forge_space = True
            with open(space_meta_path, 'rt', encoding='utf-8') as file:
                self.space_meta = json.load(file)

    def to_dict(self):
        """Return a dict of the attributes listed in `cached_fields`."""
        return {x: getattr(self, x) for x in self.cached_fields}

    def from_dict(self, d):
        """Set every attribute listed in `cached_fields` from dict `d` (all keys must be present)."""
        for field in self.cached_fields:
            setattr(self, field, d[field])

    def read_info_from_repo(self):
        """Fill in the Git-related attributes, going through the 'extensions-git' cache.

        Does nothing for built-in extensions or when the info was already read.
        """
        if self.is_builtin or self.have_info_from_repo:
            return

        def read_from_repo():
            # Passed to the cache as the function that produces the data;
            # presumably it is only called when no valid cached entry exists.
            with self.lock:
                if self.have_info_from_repo:
                    return

                self.do_read_info_from_repo()

                return self.to_dict()

        try:
            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
            self.from_dict(d)
        except FileNotFoundError:
            pass
        self.status = 'unknown' if self.status == '' else self.status

    def do_read_info_from_repo(self):
        """Read remote, commit date, branch, commit hash and version directly from the repository.

        Errors are reported rather than raised; on failure `remote` is reset to None.
        Always marks the info as read afterwards.
        """
        repo = None
        try:
            if os.path.exists(os.path.join(self.path, ".git")):
                repo = Repo(self.path)
        except Exception:
            errors.report(f"Error reading github repository info from {self.path}", exc_info=True)

        if repo is None or repo.bare:
            self.remote = None
        else:
            try:
                self.remote = next(repo.remote().urls, None)
                commit = repo.head.commit
                self.commit_date = commit.committed_date
                if repo.active_branch:
                    self.branch = repo.active_branch.name
                self.commit_hash = commit.hexsha
                # The version shown is the first 8 characters of the commit hash.
                self.version = self.commit_hash[:8]

            except Exception:
                errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
                self.remote = None

        self.have_info_from_repo = True

    def list_files(self, subdir, extension):
        """Return ScriptFile objects for the regular files in `subdir` whose suffix equals `extension`.

        `extension` is compared against the lower-cased suffix including the dot.
        Returns an empty list when `subdir` is not a directory.
        """
        dirpath = os.path.join(self.path, subdir)
        if not os.path.isdir(dirpath):
            return []

        res = []
        for filename in sorted(os.listdir(dirpath)):
            res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))

        res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

        return res

    def check_updates(self):
        """Set `can_update` and `status` according to whether the remote has newer commits."""
        repo = Repo(self.path)
        branch_name = f'{repo.remote().name}/{self.branch}'
        # First do a dry-run fetch and look for a ref that is not up to date.
        for fetch in repo.remote().fetch(dry_run=True):
            if self.branch and fetch.name != branch_name:
                continue
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "new commits"
                return

        # Then compare the local HEAD commit with the remote branch.
        try:
            origin = repo.rev_parse(branch_name)
            if repo.head.commit != origin:
                self.can_update = True
                self.status = "behind HEAD"
                return
        except Exception:
            self.can_update = False
            self.status = "unknown (remote error)"
            return

        self.can_update = False
        self.status = "latest"

    def fetch_and_reset_hard(self, commit=None):
        """Fetch all remotes and hard-reset to `commit` (default: the remote's copy of the current branch)."""
        repo = Repo(self.path)
        if commit is None:
            commit = f'{repo.remote().name}/{self.branch}'
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
        repo.git.fetch(all=True)
        repo.git.reset(commit, hard=True)
        # Force the repository info to be read again next time it is requested.
        self.have_info_from_repo = False
def list_extensions():
    """Rebuild the module-level extension registries by scanning the extension directories.

    Clears `extensions`, `extension_paths` and `loaded_extensions`, creates an
    Extension for every sub-directory of the built-in and regular extension
    directories, then reads each extension's "Requires" list and reports
    requirements of enabled extensions that are not installed or are disabled.
    """
    extensions.clear()
    extension_paths.clear()
    loaded_extensions.clear()

    if shared.cmd_opts.disable_all_extensions:
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "all":
        print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
    elif shared.cmd_opts.disable_extra_extensions:
        print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

    # Does not change while scanning, so compute it once.
    disabled_extensions = shared.opts.disabled_extensions + always_disabled_extensions

    # scan through extensions directory and load metadata
    for dirname in [extensions_builtin_dir, extensions_dir]:
        if not os.path.isdir(dirname):
            continue

        is_builtin = dirname == extensions_builtin_dir

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

            metadata = ExtensionMetadata(path, extension_dirname)
            # Register under the normalised name from the metadata: that is the form
            # used by the duplicate check below and by the lower-cased requirement
            # lookups, so keying by the raw directory name would miss them.
            canonical_name = metadata.canonical_name

            # check for duplicated canonical names
            already_loaded_extension = loaded_extensions.get(canonical_name)
            if already_loaded_extension is not None:
                errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
                continue

            extension = Extension(
                name=extension_dirname,
                path=path,
                enabled=extension_dirname not in disabled_extensions,
                is_builtin=is_builtin,
                metadata=metadata
            )
            extensions.append(extension)
            extension_paths[extension.path] = extension
            loaded_extensions[canonical_name] = extension

    # Requirements are resolved only after every extension is registered, because
    # resolving alternatives depends on which names are in loaded_extensions.
    for extension in extensions:
        extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension")

    # check for requirements
    for extension in extensions:
        if not extension.enabled:
            continue

        for req in extension.metadata.requires:
            required_extension = loaded_extensions.get(req)
            if required_extension is None:
                errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
            elif not required_extension.enabled:
                errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
def find_extension(filename):
    """Return the Extension whose directory contains `filename`, or None.

    Starts at the directory of the resolved path and walks up through parent
    directories, looking each one up in `extension_paths`, until the path no
    longer changes.
    """
    current = filename
    candidate_dir = os.path.dirname(os.path.realpath(filename))

    while candidate_dir != current:
        found = extension_paths.get(candidate_dir)
        if found is not None:
            return found

        # Move one level up; at the filesystem root dirname() returns its argument.
        current, candidate_dir = candidate_dir, os.path.dirname(candidate_dir)

    return None