Spaces:
Runtime error
Runtime error
Soutrik
committed on
Commit
•
3e4ba8b
1
Parent(s):
3749e6c
added gitignore
Browse files- .gitignore +6 -0
- artifacts/image_prediction.png +0 -0
- src/create_artifacts.py +27 -0
- src/utils/multirun_op.py +81 -0
.gitignore
CHANGED
@@ -24,3 +24,9 @@ data/
|
|
24 |
checkpoints/
|
25 |
logs/
|
26 |
/data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
checkpoints/
|
25 |
logs/
|
26 |
/data
|
27 |
+
artifacts/
|
28 |
+
artifacts/*
|
29 |
+
*png
|
30 |
+
*jpg
|
31 |
+
*jpeg
|
32 |
+
artifacts/image_prediction.png
|
artifacts/image_prediction.png
CHANGED
src/create_artifacts.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import os
|
3 |
+
from src.utils.multirun_op import multirun_artifact_producer
|
4 |
+
import hydra
|
5 |
+
from omegaconf import DictConfig
|
6 |
+
from loguru import logger
|
7 |
+
from dotenv import load_dotenv, find_dotenv
|
8 |
+
import rootutils
|
9 |
+
|
10 |
+
# Load environment variables
|
11 |
+
load_dotenv(find_dotenv(".env"))
|
12 |
+
|
13 |
+
# Setup root directory
|
14 |
+
root = rootutils.setup_root(__file__, indicator=".project-root")
|
15 |
+
|
16 |
+
|
17 |
+
@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
def create_artifacts(cfg: DictConfig):
    """Hydra entry point that aggregates the latest multirun logs into artifacts.

    Builds the runs directory path from the configured log dir and delegates the
    actual aggregation to ``multirun_artifact_producer``.

    Args:
        cfg: Hydra config; only ``cfg.paths.log_dir`` and ``cfg.paths.artifact_dir``
            are read here.
    """
    runs_root = os.path.join(cfg.paths.log_dir, "train", "runs")
    artifact_dir = cfg.paths.artifact_dir
    logger.info(f"Base path: {runs_root} and artifact directory: {artifact_dir}")
    multirun_artifact_producer(runs_root, artifact_dir)
24 |
+
|
25 |
+
|
26 |
+
if __name__ == "__main__":
    # Entry point: the @hydra.main decorator on create_artifacts supplies the
    # config (and CLI overrides) when this file is executed as a script.
    create_artifacts()
src/utils/multirun_op.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import yaml
|
4 |
+
import pandas as pd
|
5 |
+
import json
|
6 |
+
from datetime import datetime
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
|
10 |
+
def multirun_artifact_producer(base_path: str, output_path: str):
    """Aggregate hparams/metrics from the latest multirun and save them as JSON.

    Finds the most recently modified run folder under ``base_path``; for each
    sub-run inside it, reads ``csv/version_*/hparams.yaml`` and
    ``csv/version_*/metrics.csv`` from the latest ``version_*`` folder, averages
    the ``train_acc``/``val_acc`` columns, and writes everything to one
    timestamped JSON file in ``output_path`` (created if missing).

    Args:
        base_path: Directory containing top-level run folders (e.g. logs/train/runs).
        output_path: Directory where the aggregated JSON file is written.
    """
    # BUGFIX: max() over an empty glob raised ValueError before the isdir check
    # could ever run; default=None makes the "nothing found" path reachable.
    latest_folder = max(
        glob.glob(os.path.join(base_path, "*")), key=os.path.getmtime, default=None
    )
    if latest_folder is None or not os.path.isdir(latest_folder):
        logger.error("No valid run folders found!")
        return

    # Aggregated result, keyed "run_<sub-run dir name>".
    output_data = {}

    # Each sub-directory of the latest run folder is one sub-run.
    for run_dir in os.listdir(latest_folder):
        run_path = os.path.join(latest_folder, run_dir)
        if not os.path.isdir(run_path):
            continue

        # Metrics/hparams live under the sub-run's csv logger directory.
        csv_base_path = os.path.join(run_path, "csv")
        if not os.path.isdir(csv_base_path):
            logger.warning(f"No csv directory found in {run_path}. Skipping.")
            continue

        # Latest version folder in csv (same empty-glob guard as above).
        latest_csv_folder = max(
            glob.glob(os.path.join(csv_base_path, "version_*")),
            key=os.path.getmtime,
            default=None,
        )
        if latest_csv_folder is None or not os.path.isdir(latest_csv_folder):
            logger.warning(
                f"No valid version folder found in {csv_base_path}. Skipping."
            )
            continue

        # Paths to files in the latest csv version folder.
        hparams_path = os.path.join(latest_csv_folder, "hparams.yaml")
        metrics_path = os.path.join(latest_csv_folder, "metrics.csv")
        if not os.path.isfile(hparams_path) or not os.path.isfile(metrics_path):
            logger.warning(
                f"Missing hparams.yaml or metrics.csv in {latest_csv_folder}. Skipping."
            )
            continue

        with open(hparams_path, "r") as file:
            hparams = yaml.safe_load(file)

        # Read metrics.csv and average the accuracy columns.
        metrics_df = pd.read_csv(metrics_path)
        # ROBUSTNESS: a run that never logged these columns used to KeyError;
        # skip it with a warning instead, matching the other skip paths.
        missing = [c for c in ("train_acc", "val_acc") if c not in metrics_df.columns]
        if missing:
            logger.warning(
                f"Missing metric columns {missing} in {metrics_path}. Skipping."
            )
            continue
        avg_train_acc = metrics_df["train_acc"].dropna().mean()
        avg_val_acc = metrics_df["val_acc"].dropna().mean()

        output_data[f"run_{run_dir}"] = {
            "hparams": hparams,
            "metrics": {"avg_train_acc": avg_train_acc, "avg_val_acc": avg_val_acc},
        }

    # Save aggregated data to a timestamped JSON file.
    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(
        output_path, f"aggregated_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    )
    logger.info(f"Saving aggregated data to {output_file}")
    with open(output_file, "w") as json_file:
        json.dump(output_data, json_file, indent=4)
76 |
+
|
77 |
+
if __name__ == "__main__":
    # Paths
    # Standalone defaults for running this module directly (without hydra);
    # NOTE(review): presumably these mirror cfg.paths in the project config — verify.
    base_path = "./logs/train/runs"
    output_path = "./artifacts"
    multirun_artifact_producer(base_path, output_path)
|