# Spaces: Runtime error
import json
import os
import subprocess
import sys
import uuid

import yaml
from huggingface_hub import snapshot_download, delete_repo, metadata_update
# Hugging Face access token passed to every Hub call in this script
# (None when the HF_TOKEN environment variable is unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Repo id of the training dataset, read from the DATA_PATH environment
# variable (None when unset).
HF_DATASET = os.getenv("DATA_PATH")
def download_dataset(hf_dataset_path: str):
    """Download a dataset repo from the Hugging Face Hub into a fresh directory.

    Args:
        hf_dataset_path: Repo id of the dataset on the Hub.

    Returns:
        Path of the local directory (``/tmp/<random uuid>``) that holds the
        downloaded snapshot.

    Raises:
        ValueError: If ``hf_dataset_path`` is empty or ``None``.
    """
    # The repo id normally comes from the DATA_PATH environment variable and
    # is None when that variable is unset; fail early with a clear message
    # instead of passing None on to the Hub client.
    if not hf_dataset_path:
        raise ValueError(
            "hf_dataset_path must be a non-empty dataset repo id (is DATA_PATH set?)"
        )

    # A random UUID gives each call its own download directory.
    local_dir = f"/tmp/{uuid.uuid4()}"
    snapshot_download(
        repo_id=hf_dataset_path,
        token=HF_TOKEN,
        local_dir=local_dir,
        repo_type="dataset",
    )
    return local_dir
def process_dataset(dataset_dir: str):
    """Prepare a downloaded dataset directory for training, in place.

    The directory is expected to contain the images, a ``config.yaml`` and an
    optional ``metadata.jsonl`` whose records have the fields ``file_name``
    and ``prompt``. For every record a ``.txt`` file with the same base name
    as the image is written, containing the prompt; ``metadata.jsonl`` is then
    removed. Finally ``config.yaml`` is rewritten so that the first dataset of
    the first process entry points at ``dataset_dir``.

    Args:
        dataset_dir: Path of the downloaded dataset directory.

    Returns:
        ``dataset_dir``, the path to the processed dataset.

    Raises:
        ValueError: If ``config.yaml`` does not exist in ``dataset_dir``.
    """
    config_path = os.path.join(dataset_dir, "config.yaml")
    metadata_path = os.path.join(dataset_dir, "metadata.jsonl")

    if not os.path.exists(config_path):
        raise ValueError("config.yaml does not exist")

    if os.path.exists(metadata_path):
        # One JSON object per line; blank lines are skipped.
        with open(metadata_path, "r", encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]

        for record in records:
            # splitext only strips an extension from the last path component,
            # so a dot in a directory name cannot truncate the path when the
            # image file name itself has no extension.
            base_path, _ = os.path.splitext(
                os.path.join(dataset_dir, record["file_name"])
            )
            # Explicit UTF-8 so non-ASCII prompts do not depend on the locale.
            with open(base_path + ".txt", "w", encoding="utf-8") as f:
                f.write(record["prompt"])

        # The prompts now live in the .txt files; drop the metadata file.
        os.remove(metadata_path)

    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Point the training config at the local copy of the dataset.
    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_dir

    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    return dataset_dir
def run_training(hf_dataset_path: str):
    """Download and prepare the dataset, fetch ai-toolkit and start training.

    The ai-toolkit repository is cloned into ``ai-toolkit`` (relative to the
    current working directory) unless that directory already exists, checked
    out at a pinned commit with its submodules, and ``run.py`` is launched
    with the dataset's ``config.yaml``.

    Args:
        hf_dataset_path: Repo id of the dataset on the Hub.

    Returns:
        A tuple ``(process, dataset_dir)``: the ``subprocess.Popen`` handle of
        the started training process (not yet waited on) and the path of the
        processed dataset directory.

    Raises:
        subprocess.CalledProcessError: If one of the git setup steps fails.
    """
    dataset_dir = process_dataset(download_dataset(hf_dataset_path))

    toolkit_dir = "ai-toolkit"

    # check=True on every git step: launching training against a missing or
    # half-initialised checkout would only fail later with a less clear error.
    if not os.path.isdir(toolkit_dir):
        subprocess.run(
            ["git", "clone", "https://github.com/ostris/ai-toolkit.git", toolkit_dir],
            check=True,
        )
    subprocess.run(
        ["git", "checkout", "bc693488eb3cf48ded8bc2af845059d80f4cf7d0"],
        cwd=toolkit_dir,
        check=True,
    )
    subprocess.run(
        ["git", "submodule", "update", "--init", "--recursive"],
        cwd=toolkit_dir,
        check=True,
    )

    # Argument list instead of a shell string so the config path is passed
    # verbatim; sys.executable runs the training with the same interpreter
    # that is running this script.
    process = subprocess.Popen(
        [sys.executable, "run.py", os.path.join(dataset_dir, "config.yaml")],
        cwd=toolkit_dir,
        env=os.environ,
    )
    return process, dataset_dir
if __name__ == "__main__":
    process, dataset_dir = run_training(HF_DATASET)

    # Wait for the training process to finish.
    return_code = process.wait()
    if return_code != 0:
        # Do not tag the model repo or delete the source dataset after a
        # failed run; stop here with a non-zero exit status.
        sys.exit(f"Training process failed with exit code {return_code}")

    with open(os.path.join(dataset_dir, "config.yaml"), "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Model repo id is read from the save section of the training config.
    repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]

    metadata = {
        "tags": [
            "autotrain",
            "spacerunner",
            "text-to-image",
            "flux",
            "lora",
            "diffusers",
            "template:sd-lora",
        ]
    }
    metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)

    # Remove the source dataset repo once the model repo has been tagged.
    delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)