File size: 2,349 Bytes
2f4febc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
import subprocess
import yaml
import os
from .bucketeer import Bucketeer

class MultiFilter:
    """Callable predicate that checks a sample's JSON metadata against rules.

    ``rules`` is a mapping whose keys are either a single metadata key or a
    tuple of metadata keys, and whose values are callables. For a single key
    the callable receives ``x_json[key]``; for a tuple of keys it receives the
    corresponding values as positional arguments, in tuple order. A sample
    passes only when every rule returns a truthy value.

    ``default`` is the verdict returned when a sample cannot be evaluated at
    all: no ``'json'`` entry, a key missing from the metadata, undecodable
    JSON bytes, or a rule that raises.
    """

    def __init__(self, rules, default=False):
        self.rules = rules
        self.default = default

    def __call__(self, x):
        """Return whether sample ``x`` satisfies every rule.

        ``x`` must be subscriptable with ``'json'``; the value may be an
        already-decoded object or JSON-encoded ``bytes``.
        """
        try:
            x_json = x['json']
            if isinstance(x_json, bytes):
                x_json = json.loads(x_json)
            # Evaluate every rule before combining, so an exception raised by
            # any rule yields ``self.default`` regardless of rule ordering.
            validations = [
                r(*[x_json[kv] for kv in k]) if isinstance(k, tuple) else r(x_json[k])
                for k, r in self.rules.items()
            ]
            return all(validations)
        except Exception:
            # Deliberate best-effort: an unevaluable sample gets the configured
            # fallback verdict (previously hard-coded to False, which silently
            # ignored the ``default`` argument).
            return self.default

class MultiGetter:
    """Callable that extracts values from JSON metadata using rules.

    ``rules`` is a mapping whose keys are either a single metadata key or a
    tuple of metadata keys, and whose values are callables. For a single key
    the callable receives the value stored under that key; for a tuple of
    keys it receives the corresponding values as positional arguments.
    """

    def __init__(self, rules):
        self.rules = rules

    def __call__(self, x_json):
        """Apply every rule to ``x_json`` and return the results.

        ``x_json`` may be an already-decoded object or JSON-encoded ``bytes``.
        When exactly one rule is configured its result is returned directly;
        otherwise a list with one entry per rule (in rule order) is returned.
        Lookup errors and exceptions raised by rules propagate to the caller.
        """
        record = json.loads(x_json) if isinstance(x_json, bytes) else x_json

        results = []
        for key, fn in self.rules.items():
            # Treat a single key as a one-element group so both cases share
            # the same call path.
            fields = key if isinstance(key, tuple) else (key,)
            results.append(fn(*[record[name] for name in fields]))

        return results[0] if len(results) == 1 else results

def setup_webdataset_path(paths, cache_path=None):
    """Build a ``pipe:aws s3 cp { ... } -`` URL covering a set of tar files.

    Args:
        paths: A single path string or an iterable of path strings. Entries
            ending in ``.tar`` are used as-is; any other entry is listed
            recursively with ``aws s3 ls`` and every listed key ending in
            ``.tar`` is collected.
        cache_path: Optional path of a YAML file caching the resolved tar
            list. If the file exists it is loaded and ``paths`` is not
            consulted; otherwise the list is resolved and, when
            ``cache_path`` is given, written to that file.

    Returns:
        A string of the form ``"pipe:aws s3 cp { p1,p2,... } -"``.

    Raises:
        subprocess.CalledProcessError: If the listing shell command exits
            with a non-zero status.
    """
    if cache_path is not None and os.path.exists(cache_path):
        with open(cache_path, 'r', encoding='utf-8') as file:
            tar_paths = yaml.safe_load(file)
    else:
        if isinstance(paths, str):
            paths = [paths]
        tar_paths = []
        for path in paths:
            if path.strip().endswith(".tar"):
                # Avoid looking up s3 if we already have a tar file
                tar_paths.append(path)
                continue
            # First three "/"-separated components, e.g. "s3://bucket"; listed
            # keys are re-prefixed with it below.
            bucket = "/".join(path.split("/")[:3])
            # NOTE(review): `path` is interpolated into a shell command
            # unquoted, so it must come from trusted configuration. Also,
            # with a shell pipeline `check=True` presumably only reflects the
            # exit status of the last stage (awk), not of `aws` -- confirm.
            result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
            files = result.stdout.decode('utf-8').split()
            tar_paths += [f"{bucket}/{f}" for f in files if f.endswith(".tar")]

        # Only persist the listing when a cache location was requested;
        # previously `open(None, 'w')` raised TypeError for the default
        # `cache_path=None`.
        if cache_path is not None:
            with open(cache_path, 'w', encoding='utf-8') as outfile:
                yaml.dump(tar_paths, outfile, default_flow_style=False)

    tar_paths_str = ",".join(f"{p}" for p in tar_paths)
    return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"