Commit c1f3687 • devforfu committed • 1 parent: 4334bbd

Movie stills binary classifier

Files changed:
- metadata/movies_plus.jsonl +3 -0
- realfake/bin/download_s3.py +43 -18
- realfake/utils.py +5 -2
- submit_movie.sh +24 -0
metadata/movies_plus.jsonl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:750d4828d1d5051b390519fa1e964b45305a5ffae3d1ef50b783568452bc13fa
+size 5992428
realfake/bin/download_s3.py CHANGED

@@ -1,23 +1,34 @@
-from __future__ import annotations
 import tarfile
 from dataclasses import dataclass
 from pathlib import Path
+from typing import List

 import boto3
 from joblib import Parallel, delayed

-from realfake.utils import get_user_name
+from realfake.utils import get_user_name, inject_args, Args


+class DownloadArgs(Args):
+    start_idx: int = 0
+    end_idx: int = 5247
+    metadata_only: bool = False
+
+
+@inject_args
+def main(args: DownloadArgs) -> None:
+    print(args)
     bucket, prefix = "s-datasets", "laion-aesthetic/data/laion2B-en-aesthetic/"
-    start_idx, end_idx =
+    start_idx, end_idx = args.start_idx, args.end_idx
     keys_range = list(range(start_idx, end_idx))

     output_dir = Path(f"/fsx/{get_user_name()}/data/real_aes_{start_idx}_{end_idx}")
+    if not args.metadata_only:
+        output_dir.mkdir(parents=True, exist_ok=True)
+    metadata_dir = output_dir.parent/f"{output_dir.name}.metadata"
+    metadata_dir.mkdir(parents=True, exist_ok=True)

-    jobs = get_jobs(keys_range, bucket, prefix, output_dir)
+    jobs = get_jobs(keys_range, bucket, prefix, output_dir, metadata_dir, args.metadata_only)

     Parallel(n_jobs=-1, backend="multiprocessing", verbose=100)(delayed(download_and_extract)(job) for job in jobs)

@@ -29,7 +40,14 @@ class Job:
     output_dir: Path


-def get_jobs(
+def get_jobs(
+    keys_range: list,
+    bucket: str,
+    prefix: str,
+    output_dir: Path,
+    metadata_dir: Path,
+    metadata_only: bool,
+) -> List[Job]:
     client = boto3.client("s3")

     token, jobs = None, []

@@ -41,8 +59,10 @@ def get_jobs(keys_range: list, bucket: str, prefix: str, output_dir: Path) -> li

         for item in response.get("Contents"):
             key = Path(item["Key"])
-            if key.suffix == ".tar" and int(key.stem) in keys_range:
+            if key.suffix == ".tar" and int(key.stem) in keys_range and not metadata_only:
                 jobs.append(Job(bucket, key, output_dir))
+            elif key.suffix == ".parquet" and int(key.stem) in keys_range:
+                jobs.append(Job(bucket, key, metadata_dir))

         if not response["IsTruncated"]: break
         token = response["NextContinuationToken"]

@@ -52,19 +72,24 @@ def get_jobs(keys_range: list, bucket: str, prefix: str, output_dir: Path) -> li

 def download_and_extract(job: Job) -> None:
     client = boto3.client("s3")
+    filename = job.output_dir / job.key.name

     print(f"{job.key}: downloading...")
-    client.download_file(job.bucket, str(job.key),
+    client.download_file(job.bucket, str(job.key), filename)
+
+    if filename.suffix == ".tar":
+        print(f"{job.key}: extracting...")
+        with tarfile.open(filename) as tar:
+            for name in tar.getnames():
+                extracted_path = job.output_dir/name
+                if extracted_path.exists():
+                    continue
+                if name.endswith(".jpg"):
+                    tar.extract(name, job.output_dir)
+        filename.unlink()
+
     print(f"{job.key}: done!")


 if __name__ == "__main__":
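The new entrypoint relies on Args and inject_args from realfake/utils.py, neither of which appears in this diff. As a rough orientation only, here is a minimal sketch of how such a pair could work, assuming Args subclasses declare typed class attributes with defaults (as DownloadArgs does) and inject_args derives a CLI parser from the annotated parameter of the wrapped function; every name and flag spelling below is an assumption, not the project's actual interface.

# Hypothetical sketch -- NOT the realfake.utils implementation, which this diff does not show.
import argparse
import functools
import typing


class Args:
    """Assumed base class: fields are typed class attributes with default values."""

    def __init__(self, **kwargs):
        for name, default in self._fields().items():
            setattr(self, name, kwargs.get(name, default))

    @classmethod
    def _fields(cls) -> dict:
        # Collect annotated fields (e.g. start_idx, end_idx, metadata_only) with their defaults.
        return {name: getattr(cls, name) for name in typing.get_type_hints(cls)}

    def __repr__(self) -> str:
        body = ", ".join(f"{k}={getattr(self, k)!r}" for k in self._fields())
        return f"{type(self).__name__}({body})"


def inject_args(fn):
    """Assumed decorator: parse CLI flags for the function's Args parameter and inject it."""
    hints = typing.get_type_hints(fn)
    hints.pop("return", None)
    args_cls = next(iter(hints.values()))  # e.g. DownloadArgs

    @functools.wraps(fn)
    def wrapper():
        parser = argparse.ArgumentParser()
        for name, default in args_cls._fields().items():
            if isinstance(default, bool):
                parser.add_argument(f"--{name}", action="store_true", default=default)
            else:
                parser.add_argument(f"--{name}", type=type(default), default=default)
        parsed = parser.parse_args()
        return fn(args_cls(**vars(parsed)))

    return wrapper

Under those assumptions the script would be invoked roughly as python realfake/bin/download_s3.py --end_idx 100 --metadata_only; the real flag handling depends on the actual realfake.utils implementation.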
realfake/utils.py CHANGED

@@ -122,5 +122,8 @@ def find_latest_checkpoint(dirname: Path) -> Path:
     return latest


-def list_files(dirname: Path, exts: list[str]) -> list:
+def list_files(dirname: Path, exts: list[str] | None = None) -> list:
+    files = Path(dirname).iterdir()
+    if not exts:
+        return list(files)
+    return [fn for fn in files for ext in exts if fn.match(f"*.{ext}")]
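For reference, a short usage sketch of the updated list_files helper; the directory path below is a placeholder.

from pathlib import Path

from realfake.utils import list_files

data_dir = Path("/fsx/<user>/data/real_aes_0_5247")  # placeholder path

# Without exts, every entry from iterdir() is returned (files and subdirectories, non-recursive).
everything = list_files(data_dir)

# With exts, only entries matching *.jpg or *.parquet are kept.
images_and_metadata = list_files(data_dir, exts=["jpg", "parquet"])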
submit_movie.sh ADDED

@@ -0,0 +1,24 @@
+#!/bin/bash -l
+
+# SLURM SUBMIT SCRIPT
+#SBATCH --partition=g40
+#SBATCH --nodes=1
+#SBATCH --gpus=8
+#SBATCH --cpus-per-gpu=6
+#SBATCH --job-name=realfake
+#SBATCH --comment=laion
+#SBATCH --signal=SIGUSR1@90
+
+source "${HOME}/venv/bin/activate"
+
+export NCCL_DEBUG=INFO
+export PYTHONFAULTHANDLER=1
+export PYTHONPATH="${HOME}/realfake"
+
+echo "Working directory: `pwd`"
+
+srun python3 realfake/train_cluster.py \
+    -jf "${HOME}/realfake/metadata/movies_plus.jsonl" \
+    -mn convnext_small -e=40 -fe=40 -bs=128 -wl=1 -fw=0.08 \
+    --acceleratorparams.devices=8 \
+    --acceleratorparams.strategy=ddp_find_unused_parameters_false
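The script follows the usual SLURM batch pattern, so it would be submitted with sbatch submit_movie.sh; srun then launches realfake/train_cluster.py across the node's 8 GPUs. The --signal=SIGUSR1@90 directive asks SLURM to deliver SIGUSR1 90 seconds before the job's time limit, which Lightning-style trainers commonly use to checkpoint and requeue.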