File size: 1,199 Bytes
69a6cef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import glob
import os.path
from typing import List, Tuple


def _split_ckpt_name(path: str) -> Tuple[str, int]:
    """Split a checkpoint path into (name prefix, integer step).

    The file stem is divided at its last ``-``; everything before it is the
    prefix and everything after it is parsed as the step number.

    :raises ValueError: If the part after the last ``-`` is not an integer.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    prefix, _, step = stem.rpartition('-')
    return prefix, int(step)


def find_steps_in_workdir(workdir: str) -> Tuple[str, List[int]]:
    """Find the training steps that have a complete set of checkpoint files.

    Looks in ``<workdir>/ckpts`` for three kinds of files:

    * ``<name>-<step>.pt``
    * ``unet-<step>.safetensors``
    * ``text_encoder-<step>.safetensors``

    A step is reported only when all three files exist for it.

    :param workdir: Directory that contains the ``ckpts`` sub-directory.
    :return: ``(name, steps)`` where ``name`` is the shared prefix of the
        ``.pt`` files (``None`` when there are no ``.pt`` files) and ``steps``
        is the ascending list of steps present in all three file groups.
    :raises NameError: If the ``.pt`` files do not all share the same prefix.
    :raises ValueError: If a matched file name does not end in an integer step.
    """
    # Escape the directory part so glob metacharacters in the path
    # (e.g. '[', ']', '*', '?') are matched literally instead of being
    # interpreted as wildcards.
    ckpts_dir = glob.escape(os.path.join(workdir, 'ckpts'))

    pt_name = None
    pt_steps = set()
    for pt in glob.glob(os.path.join(ckpts_dir, '*-*.pt')):
        prefix, step = _split_ckpt_name(pt)
        if pt_name is None:
            pt_name = prefix
        elif pt_name != prefix:
            raise NameError(f'Name not match, {pt_name!r} vs {prefix!r}.')
        pt_steps.add(step)

    unet_steps = {
        _split_ckpt_name(unet)[1]
        for unet in glob.glob(os.path.join(ckpts_dir, 'unet-*.safetensors'))
    }
    text_encoder_steps = {
        _split_ckpt_name(text_encoder)[1]
        for text_encoder in glob.glob(os.path.join(ckpts_dir, 'text_encoder-*.safetensors'))
    }

    # Only steps for which every required file exists are usable.
    return pt_name, sorted(pt_steps & unet_steps & text_encoder_steps)