Spaces:
Runtime error
Runtime error
import glob | |
import os.path | |
from typing import List, Tuple | |
def _split_ckpt_filename(path: str) -> Tuple[str, str]:
    """Split ``<prefix>-<suffix>.<ext>`` into ``(prefix, suffix)`` strings."""
    stem = os.path.splitext(os.path.basename(path))[0]
    # Split on the LAST dash only, so prefixes may themselves contain dashes.
    prefix, _, suffix = stem.rpartition('-')
    return prefix, suffix


def find_steps_in_workdir(workdir: str) -> Tuple[str, List[int]]:
    """Find the training steps that have a complete checkpoint set in *workdir*.

    Looks inside ``<workdir>/ckpts`` for three kinds of files:

    * ``<name>-<step>.pt``
    * ``unet-<step>.safetensors``
    * ``text_encoder-<step>.safetensors``

    A step is reported only when all three files exist for it.

    :param workdir: Directory that contains the ``ckpts`` sub-directory.
    :return: ``(name, steps)`` where ``name`` is the shared prefix of the
        ``.pt`` files and ``steps`` is the sorted list of steps present in all
        three file groups. ``name`` is ``None`` when no matching ``.pt`` file
        is found (in which case ``steps`` is empty).
    :raises NameError: If the ``.pt`` files do not all share the same prefix.
    :raises ValueError: If a matched file's trailing ``-<step>`` segment is
        not an integer.
    """
    # Escape the directory part so that glob metacharacters in the path
    # (e.g. ``runs/exp[1]``) are matched literally instead of as a pattern.
    ckpts_dir = glob.escape(os.path.join(workdir, 'ckpts'))

    pt_name = None
    pt_steps = set()
    for pt in glob.glob(os.path.join(ckpts_dir, '*-*.pt')):
        name, step = _split_ckpt_filename(pt)
        if pt_name is None:
            pt_name = name
        elif pt_name != name:
            raise NameError(f'Name not match, {pt_name!r} vs {name!r}.')
        pt_steps.add(int(step))

    unet_steps = {
        int(_split_ckpt_filename(unet)[1])
        for unet in glob.glob(os.path.join(ckpts_dir, 'unet-*.safetensors'))
    }
    text_encoder_steps = {
        int(_split_ckpt_filename(text_encoder)[1])
        for text_encoder in glob.glob(os.path.join(ckpts_dir, 'text_encoder-*.safetensors'))
    }

    # Only steps for which every component was saved are usable.
    return pt_name, sorted(pt_steps & unet_steps & text_encoder_steps)