# climateGAN / utils_scripts / download_comet_images.py
# Committed by vict0rsch: initial commit from `vict0rsch/climateGAN` (ce190ee).
import argparse
import os
from collections import Counter
from pathlib import Path
import comet_ml
import yaml
from addict import Dict
from comet_ml import config
def parse_tags(tags_str):
    """
    Parse a comma-separated tag specification.

    Tags prefixed with "!" or "~" are exclusions; every other tag is required.
    Surrounding whitespace is stripped and empty entries (e.g. produced by a
    trailing comma) are ignored.

    Args:
        tags_str (str): e.g. "painter, !debug, ~old"

    Returns:
        tuple(set, set, set): (all_tags, keep_tags, remove_tags) where
            all_tags holds the raw tags (exclusions still carry their prefix),
            keep_tags the required tags and remove_tags the excluded tags with
            their "!"/"~" prefix removed.
    """
    all_tags = {t.strip() for t in tags_str.split(",") if t.strip()}
    keep_tags = set()
    remove_tags = set()
    for tag in all_tags:
        # Only a *leading* "!" / "~" marks an exclusion: a tag merely containing
        # one of these characters is a regular tag and must be kept intact.
        if tag.startswith(("!", "~")):
            name = tag[1:].strip()
            if name:
                remove_tags.add(name)
        else:
            keep_tags.add(tag)
    return all_tags, keep_tags, remove_tags
def select_lambdas(vars):
    """
    Post-processing hook: dump the painter's lambdas to a YAML file.

    Writes ``opts.train.lambdas.G.p`` to ``painter_lambdas.yaml`` inside the
    current experiment's download directory (``vars["ddir"]``), falling back to
    the working directory when no ``ddir`` is available. Nothing is written in
    ``--dev`` (dry-run) mode.

    Args:
        vars (dict): output of locals() from the download loop; must contain
            "args" and "opts", and should contain "ddir".
    """
    if vars["args"].dev:
        return
    lambdas = vars["opts"].train.lambdas.G.p
    # Writing to "./" made every experiment overwrite the same file; store the
    # lambdas next to the experiment's downloaded images instead.
    out_dir = Path(vars.get("ddir") or ".")
    with open(out_dir / "painter_lambdas.yaml", "w") as f:
        yaml.safe_dump(lambdas.to_dict(), f)
def parse_value(v: str):
    """
    Parse a stringified parameter value back into a Python value.

    Conversions, tried in order:
      * "true" / "false" (any case) -> bool
      * "[a, b, ...]" -> list, elements parsed recursively ("[]" -> [])
      * integer literals -> int
      * numeric float literals, including scientific notation such as
        "1e-05" -> float
    Any other string is returned unchanged, and non-string inputs are returned
    as is.

    Args:
        v (str): value to parse

    Returns:
        any: parsed value
    """
    if not isinstance(v, str):
        return v
    lowered = v.lower()
    if lowered == "false":
        return False
    if lowered == "true":
        return True
    if v.startswith("[") and v.endswith("]"):
        inner = v.replace("[", "").replace("]", "")
        if not inner.strip():
            return []
        return [parse_value(sub_v) for sub_v in inner.split(", ")]
    try:
        return int(v)
    except ValueError:
        pass
    # Require a digit so words float() happens to accept ("nan", "inf",
    # "infinity") stay strings; "1e-05" has no "." yet is still a float.
    if any(c.isdigit() for c in v):
        try:
            return float(v)
        except ValueError:
            pass
    return v
def parse_opts(summary):
    """
    Rebuild a nested addict.Dict from a flat parameter summary.

    Each entry's ``name`` is split on "." into nested keys (intermediate
    levels are created on the fly by addict), and its ``valueCurrent`` string
    is converted with ``parse_value``.

    Args:
        summary (list(dict)): List of dicts from exp.get_parameters_summary()

    Returns:
        addict.Dict: parsed exp params
    """
    opts = Dict()
    for entry in summary:
        value = parse_value(entry["valueCurrent"])
        *parents, leaf = entry["name"].split(".")
        node = opts
        for key in parents:
            node = node[key]
        node[leaf] = value
    return opts
def has_right_tags(exp: comet_ml.Experiment, keep: set, remove: set) -> bool:
    """
    Decide whether an experiment passes a tag filter.

    The experiment is selected when it carries every tag in ``keep`` and
    none of the tags in ``remove``.

    Args:
        exp (comet_ml.Experiment): experiment to select (or not)
        keep (set): tags the exp should have
        remove (set): tags the exp cannot have

    Returns:
        bool: should this exp be selected
    """
    exp_tags = set(exp.get_tags())
    if remove & exp_tags:
        # at least one forbidden tag is present
        return False
    return keep <= exp_tags
if __name__ == "__main__":
    # Only needed here: safely quote the output path handed to the shell.
    import shlex

    # ------------------------
    # -----  Parse args  -----
    # ------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--exp_id", type=str, default="")
    parser.add_argument(
        "-d",
        "--download_dir",
        type=str,
        default=None,
        help="Where to download the images",
    )
    parser.add_argument(
        "-s", "--step", default="last", type=str, help="`last`, `all` or `int`"
    )
    parser.add_argument(
        "-b",
        "--base_dir",
        default="./",
        type=str,
        help="if download_dir is not specified, download into base_dir/exp_id[:8]/",
    )
    parser.add_argument(
        "-t",
        "--tags",
        default="",
        type=str,
        help="download all images of all with a set of tags",
    )
    parser.add_argument(
        "-i",
        "--id_length",
        default=8,
        type=int,
        help="Length of the experiment's ID substring to make dirs: exp.id[:id_length]",
    )
    parser.add_argument(
        "--dev",
        default=False,
        action="store_true",
        help="dry run: no mkdir, no download",
    )
    parser.add_argument(
        "-p",
        "--post_processings",
        default="",
        type=str,
        help="comma separated string list of post processing functions to apply",
    )
    parser.add_argument(
        "-r",
        "--running",
        default=False,
        action="store_true",
        help="only select running exps",
    )
    args = parser.parse_args()
    print(args)

    # -------------------------------------
    # -----  Create post processings  -----
    # -------------------------------------
    # Each post process is called once per experiment with the loop's locals().
    POST_PROCESSINGS = {"select_lambdas": select_lambdas}
    requested_pp = [k.strip() for k in args.post_processings.split(",") if k.strip()]
    unknown_pp = [k for k in requested_pp if k not in POST_PROCESSINGS]
    if unknown_pp:
        # Report instead of silently dropping (typos would otherwise go unnoticed).
        print("Ignoring unknown post processings:", unknown_pp)
    post_processes = [
        POST_PROCESSINGS[k] for k in requested_pp if k in POST_PROCESSINGS
    ]

    # ------------------------------------------------------
    # -----  Create Download Dir from download_dir or  -----
    # -----  base_dir/exp_id[:args.id_length]          -----
    # ------------------------------------------------------
    download_dir = Path(args.download_dir or Path(args.base_dir)).resolve()
    if not args.dev:
        download_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------
    # -----  Check step  -----
    # ------------------------
    # `step` ends up as an int, "last" or "all".
    try:
        step = int(args.step)
    except ValueError:
        step = args.step
        if step not in {"last", "all"}:
            parser.error(f"--step must be `last`, `all` or an integer, got {step!r}")

    api = comet_ml.api.API()

    # ---------------------------------------
    # -----  Select exps based on tags  -----
    # ---------------------------------------
    if not args.tags:
        if not args.exp_id:
            parser.error("either --exp_id or --tags must be provided")
        single_exp = api.get_experiment_by_id(args.exp_id)
        if single_exp is None:
            parser.error(f"no experiment found with id {args.exp_id}")
        exps = [single_exp]
    else:
        all_tags, keep_tags, remove_tags = parse_tags(args.tags)
        download_dir = download_dir / "&".join(sorted(all_tags))
        print("Selecting experiments with tags", all_tags)
        conf = dict(config.get_config())
        exps = api.get_experiments(
            workspace=conf.get("comet.workspace"),
            project_name=conf.get("comet.project_name") or "climategan",
        )
        exps = list(filter(lambda e: has_right_tags(e, keep_tags, remove_tags), exps))
        if args.running:
            exps = [e for e in exps if e.alive]

    # -------------------------
    # -----  Print setup  -----
    # -------------------------
    print(
        "Processing {} experiments in {} with post processes {}".format(
            len(exps), str(download_dir), post_processes
        )
    )
    # Directories are named after a prefix of the experiment id: prefixes must
    # be unique or experiments would write into the same directory.
    id_counts = Counter(e.id[: args.id_length] for e in exps)
    if any(count > 1 for count in id_counts.values()):
        parser.error("Experiment ID conflict, use a larger --id_length")

    for exp_idx, exp in enumerate(exps):
        # ----------------------------------------------
        # -----  Setup Current Download Directory  -----
        # ----------------------------------------------
        cropped_id = exp.id[: args.id_length]
        ddir = (download_dir / cropped_id).resolve()
        if not args.dev:
            ddir.mkdir(parents=True, exist_ok=True)

        # ------------------------------
        # -----  Fetch image list  -----
        # ------------------------------
        ims = [asset for asset in exp.get_asset_list() if asset["image"] is True]

        # -----------------------------------
        # -----  Filter images by step  -----
        # -----------------------------------
        if step == "last":
            # Compare against None explicitly: step 0 is a valid step and must
            # not be discarded as falsy. An experiment without images yields
            # None instead of crashing on max([]).
            known_steps = [im["step"] for im in ims if im["step"] is not None]
            curr_step = max(known_steps) if known_steps else None
        else:
            curr_step = step
        ims = [im for im in ims if step == "all" or im["step"] == curr_step]
        ddir = ddir / str(curr_step)
        if not args.dev:
            ddir.mkdir(parents=True, exist_ok=True)

        # ----------------------------------------------
        # -----  Store experiment's link and opts  -----
        # ----------------------------------------------
        summary = exp.get_parameters_summary()
        opts = parse_opts(summary)
        if not args.dev:
            # Stored inside ddir: writing to "./" overwrote these files for
            # every experiment.
            with open(ddir / "url.txt", "w") as f:
                f.write(exp.url)
            with open(ddir / "opts.yaml", "w") as f:
                yaml.safe_dump(opts.to_dict(), f)

        # ------------------------------------------
        # -----  Download png files with curl  -----
        # ------------------------------------------
        print(
            " >>> Downloading exp {}'s image at step `{}` into {}".format(
                cropped_id, args.step, str(ddir)
            )
        )
        for i, im in enumerate(ims):
            # Comet provides a shell command of the form "<curl ...> > <file>".
            curl_parts = im.get("curlDownload", "").split(" > ")
            if len(curl_parts) != 2:
                print(f"\n!! Unexpected curl command for {im['fileName']}, skipping")
                continue
            curl_command, curl_target = curl_parts
            file_stem = Path(curl_target.strip().strip("\"'")).stem
            file_path = ddir / f"{file_stem}_{cropped_id}_{curr_step}.png"
            if file_path.exists():
                # Already downloaded by a previous run: check the exact path
                # that is written below so re-runs actually skip it.
                continue
            print(
                "\nDownloading exp {}/{} image {}/{}: {} in {}".format(
                    exp_idx + 1, len(exps), i + 1, len(ims), im["fileName"], ddir
                )
            )
            if not args.dev:
                status = os.system(f"{curl_command} > {shlex.quote(str(file_path))}")
                if status != 0:
                    print(f"!! curl failed (status {status}) for {im['fileName']}")
                    # Do not leave a partial file that a later run would skip.
                    if file_path.exists():
                        file_path.unlink()

        for p in post_processes:
            p(locals())