Spaces:
Runtime error
Runtime error
File size: 1,199 Bytes
69a6cef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import glob
import os.path
from typing import List, Optional, Tuple
def _split_name_and_step(path: str) -> Optional[Tuple[str, int]]:
    """Split a ``<name>-<step>.<ext>`` path into ``(name, step)``.

    The name is everything before the LAST ``-`` of the basename (without
    extension), so names may themselves contain dashes. Returns ``None``
    when there is no ``-`` or the trailing segment is not an integer.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    name, sep, step = stem.rpartition('-')
    if not sep:
        return None
    try:
        return name, int(step)
    except ValueError:
        # Stray files such as ``unet-final.safetensors`` are not step checkpoints.
        return None


def _scan_checkpoints(ckpts_dir: str, pattern: str) -> List[Tuple[str, int]]:
    """Return ``(name, step)`` for every file in *ckpts_dir* matching *pattern*.

    *ckpts_dir* is escaped so glob metacharacters in the directory path
    (e.g. ``[`` / ``]``) are matched literally; only *pattern* is a glob.
    Results are sorted by path for deterministic iteration order.
    """
    paths = sorted(glob.glob(os.path.join(glob.escape(ckpts_dir), pattern)))
    found = []
    for path in paths:
        parsed = _split_name_and_step(path)
        if parsed is not None:
            found.append(parsed)
    return found


def find_steps_in_workdir(workdir: str) -> Tuple[Optional[str], List[int]]:
    """Find training steps for which every checkpoint component was saved.

    Scans ``<workdir>/ckpts`` for three kinds of files:

    * ``<name>-<step>.pt``
    * ``unet-<step>.safetensors``
    * ``text_encoder-<step>.safetensors``

    Files whose trailing ``-`` segment is not an integer are ignored.

    :param workdir: Directory containing a ``ckpts`` subdirectory.
    :returns: ``(pt_name, steps)``. ``pt_name`` is the shared ``<name>`` prefix
        of the ``.pt`` files, or ``None`` if there are none. ``steps`` is the
        ascending list of steps present in all three kinds of files.
    :raises NameError: If the ``.pt`` files do not all share the same prefix.
    """
    ckpts_dir = os.path.join(workdir, 'ckpts')

    pt_name = None
    pt_steps = set()
    for name, step in _scan_checkpoints(ckpts_dir, '*-*.pt'):
        if pt_name is None:
            pt_name = name
        elif name != pt_name:
            raise NameError(f'Name not match, {pt_name!r} vs {name!r}.')
        pt_steps.add(step)

    unet_steps = {step for _, step in _scan_checkpoints(ckpts_dir, 'unet-*.safetensors')}
    text_encoder_steps = {
        step for _, step in _scan_checkpoints(ckpts_dir, 'text_encoder-*.safetensors')
    }

    # Only steps where every component exists can be loaded as a full checkpoint.
    return pt_name, sorted(pt_steps & unet_steps & text_encoder_steps)
|