Soutrik committed
Commit 3e4ba8b
1 Parent(s): 3749e6c

added gitignore

.gitignore CHANGED
@@ -24,3 +24,9 @@ data/
 checkpoints/
 logs/
 /data
+artifacts/
+artifacts/*
+*png
+*jpg
+*jpeg
+artifacts/image_prediction.png
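Since artifacts/ and artifacts/* already cover everything in that folder, the final artifacts/image_prediction.png rule is redundant; which pattern actually applies to a path can be checked from the repo root:

    $ git check-ignore -v artifacts/image_prediction.png

This prints the matching .gitignore line and pattern. Note that ignore rules only affect untracked files, so the already-committed image_prediction.png (changed below) stays tracked unless removed with git rm --cached.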
artifacts/image_prediction.png CHANGED
src/create_artifacts.py ADDED
@@ -0,0 +1,27 @@
+from typing import List
+import os
+from src.utils.multirun_op import multirun_artifact_producer
+import hydra
+from omegaconf import DictConfig
+from loguru import logger
+from dotenv import load_dotenv, find_dotenv
+import rootutils
+
+# Load environment variables
+load_dotenv(find_dotenv(".env"))
+
+# Setup root directory
+root = rootutils.setup_root(__file__, indicator=".project-root")
+
+
+@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
+def create_artifacts(cfg: DictConfig):
+    base_path = os.path.join(cfg.paths.log_dir, "train", "runs")
+    logger.info(
+        f"Base path: {base_path} and artifact directory: {cfg.paths.artifact_dir}"
+    )
+    multirun_artifact_producer(base_path, cfg.paths.artifact_dir)
+
+
+if __name__ == "__main__":
+    create_artifacts()
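For context, this entry point is meant to run after a Hydra sweep has filled logs/train/runs with per-run sub-directories. A sketch of the workflow, where the training script name and override keys are assumptions (they depend on what configs/train.yaml exposes):

    $ python src/train.py --multirun model.lr=0.01,0.001   # hypothetical sweep entry point and keys
    $ python src/create_artifacts.py                       # aggregate the latest multirun into artifacts/

create_artifacts() itself only resolves cfg.paths.log_dir and cfg.paths.artifact_dir from the Hydra config and delegates to multirun_artifact_producer, added below.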
src/utils/multirun_op.py ADDED
@@ -0,0 +1,81 @@
+import os
+import glob
+import yaml
+import pandas as pd
+import json
+from datetime import datetime
+from loguru import logger
+
+
+def multirun_artifact_producer(base_path: str, output_path: str):
+    """Aggregate data from the latest run's csv folder and save to a JSON file."""
+    # Find the latest top-level run folder; default="" fails the isdir check on an empty glob
+    latest_folder = max(glob.glob(os.path.join(base_path, "*")), key=os.path.getmtime, default="")
+    if not os.path.isdir(latest_folder):
+        logger.error("No valid run folders found!")
+        return
+
+    # Initialize JSON structure
+    output_data = {}
+    # Process each sub-run directory within the latest run folder
+    for run_dir in os.listdir(latest_folder):
+        run_path = os.path.join(latest_folder, run_dir)
+        if os.path.isdir(run_path):
+            # Look for the latest folder in the csv subdirectory
+            csv_base_path = os.path.join(run_path, "csv")
+            if not os.path.isdir(csv_base_path):
+                logger.warning(f"No csv directory found in {run_path}. Skipping.")
+                continue
+
+            # Find the latest version folder in csv
+            latest_csv_folder = max(
+                glob.glob(os.path.join(csv_base_path, "version_*")),
+                key=os.path.getmtime, default="",
+            )
+            if not os.path.isdir(latest_csv_folder):
+                logger.warning(
+                    f"No valid version folder found in {csv_base_path}. Skipping."
+                )
+                continue
+
+            # Paths to files in the latest csv version folder
+            hparams_path = os.path.join(latest_csv_folder, "hparams.yaml")
+            metrics_path = os.path.join(latest_csv_folder, "metrics.csv")
+
+            # Check if necessary files exist
+            if not os.path.isfile(hparams_path) or not os.path.isfile(metrics_path):
+                logger.warning(
+                    f"Missing hparams.yaml or metrics.csv in {latest_csv_folder}. Skipping."
+                )
+                continue
+
+            # Read hparams.yaml
+            with open(hparams_path, "r") as file:
+                hparams = yaml.safe_load(file)
+
+            # Read metrics.csv and calculate averages
+            metrics_df = pd.read_csv(metrics_path)
+            avg_train_acc = metrics_df["train_acc"].dropna().mean()
+            avg_val_acc = metrics_df["val_acc"].dropna().mean()
+
+            # Create JSON structure for this run
+            output_data[f"run_{run_dir}"] = {
+                "hparams": hparams,
+                "metrics": {"avg_train_acc": avg_train_acc, "avg_val_acc": avg_val_acc},
+            }
+
+    # Save aggregated data to JSON
+    os.makedirs(output_path, exist_ok=True)
+    output_file = os.path.join(
+        output_path, f"aggregated_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+    )
+    logger.info(f"Saving aggregated data to {output_file}")
+    with open(output_file, "w") as json_file:
+        json.dump(output_data, json_file, indent=4)
+
+
+if __name__ == "__main__":
+    # Paths
+    base_path = "./logs/train/runs"
+    output_path = "./artifacts"
+    multirun_artifact_producer(base_path, output_path)
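The run_dir/csv/version_N layout with hparams.yaml and metrics.csv is what PyTorch Lightning's CSVLogger typically writes, and it is what this helper walks. A minimal sketch of consuming the resulting artifact, assuming the default ./artifacts path:

    import glob
    import json
    import os

    # pick the most recent aggregated_data_*.json written by multirun_artifact_producer
    latest = max(glob.glob("./artifacts/aggregated_data_*.json"), key=os.path.getmtime)
    with open(latest) as f:
        runs = json.load(f)

    # entries are keyed "run_<sub-run dir>" and hold hparams plus averaged accuracies
    for name, payload in runs.items():
        print(name, payload["metrics"]["avg_val_acc"])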