# Spaces: Sleeping Sleeping  (hosting-page status residue; not part of the module)
from __future__ import annotations | |
import os | |
import base64 | |
import json | |
import time | |
import logging | |
import folder_paths | |
import glob | |
import comfy.utils | |
from aiohttp import web | |
from PIL import Image | |
from io import BytesIO | |
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types | |
class ModelFileManager:
    """List model files per model folder and serve their preview images.

    Directory scans are cached per model directory. A cached scan is reused
    only while the modification time of every directory recorded during that
    scan (the scanned root and its sub-directories) is unchanged.
    """

    def __init__(self) -> None:
        # Keyed by model directory path. Each value is the tuple produced by
        # recursive_search_models_:
        #   (file entries, {directory path: mtime at scan time}, scan timestamp)
        self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}

    def __enter__(self) -> ModelFileManager:
        """Return the manager itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop all cached scans when leaving a ``with`` block."""
        self.clear_cache()

    def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
        """Return the cached scan stored under ``key``, or ``default`` if absent."""
        return self.cache.get(key, default)

    def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
        """Store ``value`` as the cached scan for ``key``."""
        self.cache[key] = value

    def clear_cache(self):
        """Remove every cached scan."""
        self.cache.clear()

    def add_routes(self, routes):
        """Register the experimental model-browsing endpoints on ``routes``.

        ``routes`` must provide a ``get(path)`` decorator (as aiohttp's
        ``web.RouteTableDef`` does).

        NOTE(review): the previous revision defined these handlers without
        ever registering them, so this method had no effect. The paths below
        are reconstructed from the "experiment to replace `/models`" notes and
        from the ``match_info`` keys each handler reads - confirm them against
        the frontend before shipping.
        """

        # NOTE: This is an experiment to replace `/models`
        @routes.get("/experiment/models")
        async def get_model_folders(request):
            folder_black_list = ["configs", "custom_nodes"]
            output_folders: list[dict] = [
                {"name": folder, "folders": folder_paths.get_folder_paths(folder)}
                for folder in folder_paths.folder_names_and_paths
                if folder not in folder_black_list
            ]
            return web.json_response(output_folders)

        # NOTE: This is an experiment to replace `/models/{folder}`
        @routes.get("/experiment/models/{folder}")
        async def get_all_models(request):
            folder = request.match_info.get("folder", None)
            if folder not in folder_paths.folder_names_and_paths:
                return web.Response(status=404)
            files = self.get_model_file_list(folder)
            return web.json_response(files)

        # `filename` uses a catch-all pattern because the names produced by
        # get_model_file_list are relative paths that may contain separators.
        @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
        async def get_model_preview(request):
            folder_name = request.match_info.get("folder", None)
            filename = request.match_info.get("filename", None)
            try:
                path_index = int(request.match_info.get("path_index", None))
            except (TypeError, ValueError):
                # Missing or non-numeric index is a client error, not a crash.
                return web.Response(status=400)

            if folder_name not in folder_paths.folder_names_and_paths or not filename:
                return web.Response(status=404)

            search_dirs = folder_paths.folder_names_and_paths[folder_name][0]
            # Reject negative indexes too: they would silently wrap around.
            if not 0 <= path_index < len(search_dirs):
                return web.Response(status=404)

            # `filename` comes straight from the URL; never let it escape the
            # configured model directory.
            full_filename = self._resolve_inside(search_dirs[path_index], filename)
            if full_filename is None:
                return web.Response(status=404)

            previews = self.get_model_previews(full_filename)
            default_preview = previews[0] if previews else None
            if default_preview is None or (
                isinstance(default_preview, str) and not os.path.isfile(default_preview)
            ):
                return web.Response(status=404)

            try:
                with Image.open(default_preview) as img:
                    img_bytes = BytesIO()
                    img.save(img_bytes, format="WEBP")
                    return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
            except Exception:
                # Request boundary: an unreadable or corrupt preview is
                # reported as "not found" rather than as a server error.
                logging.warning(f"Warning: Unable to render preview for {full_filename}.")
                return web.Response(status=404)

    @staticmethod
    def _resolve_inside(folder: str, filename: str) -> str | None:
        """Join ``filename`` onto ``folder``; return None if the result leaves ``folder``.

        Paths are normalised lexically (``abspath``, not ``realpath``) so that
        symlinked sub-directories, which the directory scan follows, stay
        reachable.
        """
        base = os.path.abspath(folder)
        candidate = os.path.abspath(os.path.join(base, filename))
        try:
            if os.path.commonpath([base, candidate]) != base:
                return None
        except ValueError:
            # Raised e.g. for paths on different drives on Windows.
            return None
        return candidate

    def get_model_file_list(self, folder_name: str):
        """Return the file entries of every existing directory of ``folder_name``.

        Each entry is ``{"name": <path relative to its directory>,
        "pathIndex": <index of that directory in the folder's path list>}``.
        Directories are rescanned only when their cached scan is stale.
        """
        folder_name = map_legacy(folder_name)
        folders = folder_paths.folder_names_and_paths[folder_name]
        output_list: list[dict] = []

        for index, folder in enumerate(folders[0]):
            if not os.path.isdir(folder):
                continue
            out = self.cache_model_file_list_(folder)
            if out is None:
                out = self.recursive_search_models_(folder, index)
                self.set_cache(folder, out)
            output_list.extend(out[0])

        return output_list

    def cache_model_file_list_(self, folder: str):
        """Return the cached scan for ``folder`` if it is still valid, else None.

        The scan is valid when ``folder`` still exists, its own mtime was
        recorded by the scan, and every recorded directory still has the
        mtime it had at scan time.
        """
        cached = self.get_cache(folder)
        if cached is None:
            return None
        if not os.path.isdir(folder):
            return None

        recorded_mtimes = cached[1]
        # A scan that did not record the root cannot detect files added or
        # removed directly inside it, so treat it as stale.
        if folder not in recorded_mtimes:
            return None

        for path, time_modified in recorded_mtimes.items():
            try:
                if os.path.getmtime(path) != time_modified:
                    return None
            except OSError:
                # A recorded directory disappeared or became unreadable.
                return None

        return cached

    def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[dict], dict[str, float], float]:
        """Walk ``directory`` and collect model files plus directory mtimes.

        Returns ``(entries, dirs, timestamp)`` where ``entries`` is a list of
        ``{"name": relative_path, "pathIndex": pathIndex}``, ``dirs`` maps
        ``directory`` and each visited sub-directory to its mtime, and
        ``timestamp`` is ``time.perf_counter()`` at the end of the scan.
        ``.git`` and dot-prefixed directories/files are skipped; symlinks are
        followed.
        """
        if not os.path.isdir(directory):
            return [], {}, time.perf_counter()

        excluded_dir_names = [".git"]
        # TODO use settings
        include_hidden_files = False

        result: list[str] = []
        dirs: dict[str, float] = {}

        # Record the root as well, so changes directly inside it invalidate
        # the cached scan.
        try:
            dirs[directory] = os.path.getmtime(directory)
        except OSError:
            logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
            return [], {}, time.perf_counter()

        for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
            # Prune in place so os.walk does not descend into excluded dirs.
            subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
            if not include_hidden_files:
                subdirs[:] = [d for d in subdirs if not d.startswith(".")]
                filenames = [f for f in filenames if not f.startswith(".")]

            filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)

            for file_name in filenames:
                try:
                    relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
                except (OSError, ValueError):
                    logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
                    continue
                result.append(relative_path)

            for d in subdirs:
                path: str = os.path.join(dirpath, d)
                try:
                    dirs[path] = os.path.getmtime(path)
                except OSError:
                    logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                    continue

        return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()

    def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
        """Return preview images for the model at ``filepath``.

        Image files next to the model named ``<model>.<ext>`` or
        ``<model>.preview.<ext>`` come first (as paths), followed by cover
        images embedded in a sibling ``.safetensors`` header (as ``BytesIO``).
        Returns an empty list when the model's directory does not exist.
        Which files count as images is decided by
        ``filter_files_content_types(..., "image")``.
        """
        dirname = os.path.dirname(filepath)
        if not os.path.exists(dirname):
            return []

        basename = os.path.splitext(filepath)[0]
        # Escape glob metacharacters: a "[" or "]" in a model name would
        # otherwise be read as a character class and match nothing.
        match_files = glob.glob(f"{glob.escape(basename)}.*", recursive=False)
        image_files = filter_files_content_types(match_files, "image")
        safetensors_file = next((f for f in match_files if f.endswith(".safetensors")), None)

        result: list[str | BytesIO] = []
        accepted_stems = (basename, f"{basename}.preview")
        for filename in image_files:
            if os.path.splitext(filename)[0] in accepted_stems:
                result.append(filename)

        if safetensors_file:
            # glob results already carry the directory part of the pattern,
            # so the path is used as-is rather than joined with dirname again.
            result.extend(self._read_embedded_covers(safetensors_file))

        return result

    def _read_embedded_covers(self, safetensors_filepath: str) -> list[BytesIO]:
        """Decode cover images from a safetensors header's ``ssmd_cover_images`` field.

        Returns an empty list when the header is empty, the field is missing,
        or the metadata cannot be decoded.
        """
        header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
        if not header:
            return []
        try:
            safetensors_metadata = json.loads(header)
            safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
            if not safetensors_images:
                return []
            return [BytesIO(base64.b64decode(image)) for image in json.loads(safetensors_images)]
        except (ValueError, TypeError, AttributeError):
            # Invalid JSON/base64 (ValueError) or unexpectedly shaped metadata.
            # Previews are best-effort, so skip the embedded images.
            logging.warning(f"Warning: Unable to read embedded previews from {safetensors_filepath}.")
            return []