# image_generator/loader.py
# Last commit 89bc548 ("update") by nsfwalex
import os
import sys
from huggingface_hub import hf_hub_download
def _parse_file_spec(file_str: str) -> "tuple[str, str | None, str]":
    """Split an 'owner/repo/[subfolder/...]/filename' spec into its parts.

    Parameters:
    - file_str (str): e.g. 'myorg/myrepo/mysubfolder/myscript.py'

    Returns:
    - tuple: (repo_id, subfolder, filename). subfolder is None when the file
      sits at the repository root.

    Raises:
    - ValueError: if the spec lacks an owner, a repo name or a filename.
    """
    # Drop empty segments so doubled or trailing slashes
    # ('org//repo/f.py', 'org/repo/f.py/') cannot produce an empty filename.
    parts = [part for part in file_str.strip().split("/") if part]
    # A Hub repo_id is itself 'owner/name', so a valid spec needs at least three
    # segments; with only two, the repo name would be mistaken for the filename.
    if len(parts) < 3:
        raise ValueError(
            f"Invalid file specification '{file_str}'. "
            "Expected format: 'repo_id/[subfolder]/filename'"
        )
    repo_id = "/".join(parts[:2])
    # Everything between the repo_id and the filename is the subfolder path.
    subfolder = "/".join(parts[2:-1]) or None
    filename = parts[-1]
    return repo_id, subfolder, filename


def _link_into_cwd(file_path: str, filename: str) -> None:
    """Best-effort: make ./<filename> a relative symlink to `file_path`.

    Does nothing when the file already lives in the current directory. An
    existing file or link named `filename` in the current directory is
    removed first. Failures are printed and otherwise ignored.
    """
    current_dir = os.path.abspath(".")
    downloaded_file_abs = os.path.abspath(file_path)
    if os.path.dirname(downloaded_file_abs) == current_dir:
        return

    symlink_path = os.path.join(current_dir, filename)
    # lexists() is True for regular files and for symlinks, including dangling ones.
    if os.path.lexists(symlink_path):
        try:
            os.remove(symlink_path)
            print(f"Removed existing link or file: '{symlink_path}'")
        except OSError as e:
            print(f"Error removing existing link '{symlink_path}': {e}")
            return

    try:
        # relpath() raises ValueError on Windows when the paths are on different drives.
        relative_target = os.path.relpath(downloaded_file_abs, current_dir)
        os.symlink(relative_target, symlink_path)
        print(f"Created symlink: '{symlink_path}' -> '{relative_target}'")
    except (OSError, ValueError) as e:
        # NOTE: on Windows, creating symlinks may require elevated privileges;
        # copying the file (shutil.copy2) would be a possible fallback.
        print(f"Failed to create symlink for '{filename}': {e}")


def load_script(file_str: str, *, repo_type: str = "space") -> "str | None":
    """
    Downloads a file from the Hugging Face Hub and ensures a symlink exists in the current directory.

    The file is downloaded (always re-downloaded) into the current directory.
    If it lands in a subfolder, a relative symlink named after the file is
    created in the current directory. This step is best effort: a symlink
    failure does not affect the return value.

    Parameters:
    - file_str (str): Path in the format 'repo_id/[subfolder]/filename', e.g., 'myorg/myrepo/mysubfolder/myscript.py'
    - repo_type (str): Hub repository type passed to hf_hub_download; defaults to "space".

    Returns:
    - str | None: The path to the downloaded file, or None if the spec is
      invalid or the download failed (the error is printed, not raised).
    """
    try:
        repo_id, subfolder, filename = _parse_file_spec(file_str)

        # Retrieve HF token from environment
        hf_token = os.getenv("HF_TOKEN", None)
        if not hf_token:
            print("Warning: 'HF_TOKEN' environment variable not set. Proceeding without authentication.")

        file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            repo_type=repo_type,
            # An empty HF_TOKEN must not be sent as a (blank) credential.
            token=hf_token or None,
            local_dir=".",  # Download into the current directory
            force_download=True,
        )
        print(f"Downloaded '{filename}' from '{repo_id}' to '{file_path}'")
    except Exception as e:
        print(f"Error downloading the script '{file_str}': {e}")
        return None

    # Kept outside the try: a symlink problem must not discard a successful download.
    _link_into_cwd(file_path, filename)
    return file_path
def load_scripts():
    """
    Downloads and executes scripts based on a file list from the Hugging Face Hub.

    Steps:
    1. Retrieve the 'FILE_LIST' environment variable, which specifies the file list path
       (same 'repo_id/[subfolder]/filename' format accepted by `load_script()`).
    2. Download the file list using `load_script()`.
    3. Read each non-blank line from the downloaded file list, where each line specifies
       another file to download.
    4. After downloading all files, execute the last successfully downloaded file in this
       module's globals.

    Failures are reported with print() and end or skip the current step; nothing is raised
    (SystemExit from the executed script still propagates).
    """
    # Local import keeps the module's top-level imports unchanged.
    import traceback

    file_list = os.getenv("FILE_LIST", "").strip()
    if not file_list:
        print("No 'FILE_LIST' environment variable set. Nothing to download.")
        return
    print(f"FILE_LIST: '{file_list}'")

    # Step 1: Download the file list itself
    file_list_path = load_script(file_list)
    if not file_list_path or not os.path.exists(file_list_path):
        print(f"Could not download or find file list: '{file_list_path}'")
        return

    # Step 2: Read each line in the downloaded file list.
    # Decode explicitly instead of relying on the locale encoding; 'utf-8-sig'
    # also strips a leading BOM that would otherwise corrupt the first spec.
    try:
        with open(file_list_path, "r", encoding="utf-8-sig") as f:
            lines = [line.strip() for line in f if line.strip()]
        print(f"Found {len(lines)} files to download from the file list.")
    except Exception as e:
        print(f"Error reading file list '{file_list_path}': {e}")
        return

    # Step 3: Download each file from the lines
    downloaded_files = []
    for idx, file_str in enumerate(lines, start=1):
        print(f"Downloading file {idx}/{len(lines)}: '{file_str}'")
        file_path = load_script(file_str)
        if file_path:
            downloaded_files.append(file_path)

    if not downloaded_files:
        print("No files were downloaded to execute.")
        return

    # Step 4: Execute the last downloaded file.
    # NOTE(security): this runs code fetched from the Hub with this process's full
    # privileges; FILE_LIST must only point at repositories you trust.
    last_file_path = downloaded_files[-1]
    print(f"Executing the last downloaded script: '{last_file_path}'")
    try:
        # Reading bytes lets compile() honour a PEP 263 coding declaration (UTF-8
        # by default), and passing the real path makes tracebacks name the script
        # file instead of '<string>'.
        with open(last_file_path, "rb") as f:
            script_source = f.read()
        code = compile(script_source, last_file_path, "exec")
        exec(code, globals())
        print(f"Successfully executed '{last_file_path}'")
    except Exception as e:
        print(f"Error executing the last downloaded script '{last_file_path}': {e}")
        # The one-line message hides where the script failed; show the traceback too.
        traceback.print_exc()
# Module entry point. This call is not behind an `if __name__ == "__main__":`
# guard, so it runs on import as well: importing this module triggers
# load_scripts(), which downloads the files named by FILE_LIST and executes
# the last one.
load_scripts()