Spaces:
Configuration error
Configuration error
export LOGLEVEL=INFO | |
output_dir=output | |
#### Infer Hyper | |
default_step=20 # inference step for diffusion model | |
default_bs=50 # batch size for inference | |
default_sample_nums=30000 # inference first $sample_nums sample in list(json.keys()) | |
default_sampling_algo="dpm-solver" | |
default_json_file="data/test/PG-eval-data/MJHQ-30K/meta_data.json" # MJHQ-30K meta json | |
default_add_label='' | |
#### Metrics Hyper | |
default_img_size=512 # 1024 # img size for fid reference embedding | |
default_fid_suffix_label='' # suffix of the line chart on wandb | |
default_log_fid=false #false | |
default_log_clip_score=false | |
default_log_image_reward=false | |
default_log_dpg=false | |
# 👇No need to change the code below | |
if [ -n "$1" ]; then | |
config_file=$1 | |
fi | |
if [ -n "$2" ]; then | |
model_paths_file=$2 | |
fi | |
for arg in "$@" | |
do | |
case $arg in | |
--np=*) | |
np="${arg#*=}" | |
shift | |
;; | |
--step=*) | |
step="${arg#*=}" | |
shift | |
;; | |
--bs=*) | |
bs="${arg#*=}" | |
shift | |
;; | |
--sample_nums=*) | |
sample_nums="${arg#*=}" | |
shift | |
;; | |
--sampling_algo=*) | |
sampling_algo="${arg#*=}" | |
shift | |
;; | |
--json_file=*) | |
json_file="${arg#*=}" | |
shift | |
;; | |
--exist_time_prefix=*) | |
exist_time_prefix="${arg#*=}" | |
shift | |
;; | |
--img_size=*) | |
img_size="${arg#*=}" | |
shift | |
;; | |
--dataset=*) | |
dataset="${arg#*=}" | |
shift | |
;; | |
--cfg_scale=*) | |
cfg_scale="${arg#*=}" | |
shift | |
;; | |
--fid_suffix_label=*) | |
fid_suffix_label="${arg#*=}" | |
shift | |
;; | |
--add_label=*) | |
add_label="${arg#*=}" | |
shift | |
;; | |
--log_fid=*) | |
log_fid="${arg#*=}" | |
shift | |
;; | |
--log_clip_score=*) | |
log_clip_score="${arg#*=}" | |
shift | |
;; | |
--log_image_reward=*) | |
log_image_reward="${arg#*=}" | |
shift | |
;; | |
--log_dpg=*) | |
log_dpg="${arg#*=}" | |
shift | |
;; | |
--inference=*) | |
inference="${arg#*=}" | |
shift | |
;; | |
--fid=*) | |
fid="${arg#*=}" | |
shift | |
;; | |
--clipscore=*) | |
clipscore="${arg#*=}" | |
shift | |
;; | |
--imagereward=*) | |
imagereward="${arg#*=}" | |
shift | |
;; | |
--dpg=*) | |
dpg="${arg#*=}" | |
shift | |
;; | |
--output_dir=*) | |
output_dir="${arg#*=}" | |
shift | |
;; | |
--auto_ckpt=*) | |
auto_ckpt="${arg#*=}" | |
shift | |
;; | |
--auto_ckpt_interval=*) | |
auto_ckpt_interval="${arg#*=}" | |
shift | |
;; | |
--tracker_pattern=*) | |
tracker_pattern="${arg#*=}" | |
shift | |
;; | |
--ablation_key=*) | |
ablation_key="${arg#*=}" | |
shift | |
;; | |
--ablation_selections=*) | |
ablation_selections="${arg#*=}" | |
shift | |
;; | |
--inference_script=*) | |
inference_script="${arg#*=}" | |
shift | |
;; | |
*) | |
;; | |
esac | |
done | |
inference=${inference:-true} | |
fid=${fid:-true} | |
clipscore=${clipscore:-true} | |
imagereward=${imagereward:-false} | |
dpg=${dpg:-false} | |
np=${np:-8} | |
step=${step:-$default_step} | |
bs=${bs:-$default_bs} | |
dataset=${dataset:-'custom'} | |
cfg_scale=${cfg_scale:-4.5} | |
sample_nums=${sample_nums:-$default_sample_nums} | |
sampling_algo=${sampling_algo:-$default_sampling_algo} | |
json_file=${json_file:-$default_json_file} | |
exist_time_prefix=${exist_time_prefix:-$default_exist_time_prefix} | |
add_label=${add_label:-$default_add_label} | |
ablation_key=${ablation_key:-''} | |
ablation_selections=${ablation_selections:-''} | |
img_size=${img_size:-$default_img_size} | |
fid_suffix_label=${fid_suffix_label:-$default_fid_suffix_label} | |
tracker_pattern=${tracker_pattern:-"epoch_step"} | |
log_fid=${log_fid:-$default_log_fid} | |
log_clip_score=${log_clip_score:-$default_log_clip_score} | |
log_image_reward=${log_image_reward:-$default_log_image_reward} | |
log_dpg=${log_dpg:-$default_log_dpg} | |
auto_ckpt=${auto_ckpt:-false} | |
auto_ckpt_interval=${auto_ckpt_interval:-0} | |
job_name=$(basename $(dirname $(dirname "$model_paths_file"))) | |
metric_dir=$output_dir/$job_name/metrics | |
if [ ! -d "$metric_dir" ]; then | |
echo "Creating directory: $metric_dir" | |
mkdir -p "$metric_dir" | |
fi | |
# select all the last step ckpts of one epoch to inference | |
if [ "$auto_ckpt" = true ]; then | |
bash scripts/collect_pth_path.sh $output_dir/$job_name/checkpoints $auto_ckpt_interval | |
fi | |
# ============ 1. start of inference block =========== | |
cache_file_path=$model_paths_file | |
if [ ! -e "$model_paths_file" ]; then | |
cache_file_path=$output_dir/$job_name/metrics/cached_img_paths_${dataset}.txt | |
echo "$model_paths_file not exists, use default image path: $cache_file_path" | |
fi | |
if [ "$inference" = true ]; then | |
inference_script=${inference_script:-"scripts/inference.py"} | |
cache_file_path=$output_dir/$job_name/metrics/cached_img_paths_${dataset}.txt | |
rm $metric_dir/tmp_${dataset}* || true | |
read -r -d '' cmd <<EOF | |
bash scripts/infer_run_inference.sh $config_file $model_paths_file --np=$np \ | |
--inference_script=$inference_script --step=$step --bs=$bs --sample_nums=$sample_nums --json_file=$json_file \ | |
--add_label=$add_label \ | |
--exist_time_prefix=$exist_time_prefix --if_save_dirname=true --sampling_algo=$sampling_algo \ | |
--cfg_scale=$cfg_scale --dataset=$dataset \ | |
--ablation_key=$ablation_key --ablation_selections="$ablation_selections" | |
EOF | |
echo $cmd | |
bash -c "${cmd}" | |
> "$cache_file_path" # clean file | |
# add all tmp*.txt file into $cache_file_path | |
for file in $metric_dir/tmp_${dataset}*.txt; do | |
if [ -f "$file" ]; then | |
cat "$file" >> "$cache_file_path" | |
echo "" >> "$cache_file_path" # add new line | |
fi | |
done | |
rm -r $metric_dir/tmp_${dataset}* || true | |
fi | |
img_path=${output_dir}/${job_name}/vis | |
exp_paths_file=${cache_file_path} | |
# ============ 2. start of fid block ================= | |
if [ "$fid" = true ]; then | |
read -r -d '' cmd <<EOF | |
bash tools/metrics/compute_fid_embedding.sh $img_path $exp_paths_file \ | |
--sample_nums=$sample_nums --img_size=$img_size --suffix_label=$fid_suffix_label \ | |
--log_fid=$log_fid --tracker_pattern=$tracker_pattern | |
EOF | |
echo $cmd | |
bash -c "${cmd}" | |
fi | |
# ============ 3. start of clip-score block ================= | |
if [ "$clipscore" = true ]; then | |
read -r -d '' cmd <<EOF | |
bash tools/metrics/compute_clipscore.sh $img_path $exp_paths_file \ | |
--sample_nums=$sample_nums --suffix_label=$fid_suffix_label \ | |
--log_clip_score=$log_clip_score --tracker_pattern=$tracker_pattern | |
EOF | |
echo $cmd | |
bash -c "${cmd}" | |
fi | |
# ============ 4. start of image-reward block ================= | |
if [ "$imagereward" = true ]; then | |
read -r -d '' cmd <<EOF | |
bash tools/metrics/compute_imagereward.sh $img_path $exp_paths_file \ | |
--sample_nums=$sample_nums --suffix_label=$fid_suffix_label \ | |
--log_image_reward=$log_image_reward --tracker_pattern=$tracker_pattern | |
EOF | |
echo $cmd | |
bash -c "${cmd}" | |
fi | |
# ============ 4. start of dpg-bench block ================= | |
if [ "$dpg" = true ]; then | |
read -r -d '' cmd <<EOF | |
bash tools/metrics/compute_dpg.sh $img_path $exp_paths_file \ | |
--sample_nums=$sample_nums --img_size=$img_size --suffix_label=$fid_suffix_label \ | |
--log_dpg=$log_dpg --tracker_pattern=$tracker_pattern | |
EOF | |
echo $cmd | |
bash -c "${cmd}" | |
fi | |