# deepspeed/scripts/tools/parse_log_to_get_eval_loss.py
# (xingzhikb, commit 002bd9b: "init")
import ast
import json
import os
import re
import sys

import matplotlib.pyplot as plt
import pandas as pd
def extract_dicts(line):
    """Return every ``{...}`` substring of *line* that contains no ``}`` inside.

    Matches are returned in left-to-right order as plain strings; an empty
    list is returned when the line holds no brace-delimited text.
    """
    brace_spans = re.finditer(r"{[^}]*}", line)
    return [span.group(0) for span in brace_spans]
def _parse_metrics_dict(text):
    """Parse one ``{...}`` snippet taken from the log into a dict, or return None.

    The quote-swap + ``json.loads`` approach is tried first. If it fails (for
    example a string value containing an apostrophe, or ``True``/``None``
    literals in a Python dict repr), ``ast.literal_eval`` is tried instead.
    """
    try:
        return json.loads(text.replace("'", '"'))
    except json.JSONDecodeError:
        pass
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return None
    # literal_eval of "{...}" may also yield a set; only dicts are records.
    return parsed if isinstance(parsed, dict) else None


def _collect_eval_records(lines):
    """Return the parsed ``eval_loss`` dicts found in *lines*, with epochs offset."""
    eval_loss_pattern = r"{'eval_loss'"
    records = []
    last_epoch = 0
    epoch_offset = 0
    for line in lines:
        for snippet in extract_dicts(line):
            if not re.search(eval_loss_pattern, snippet):
                continue
            parsed = _parse_metrics_dict(snippet)
            if parsed is None:
                print(f"Could not parse line: {snippet}")
                continue
            if "epoch" not in parsed:
                # Previously this raised an uncaught KeyError and aborted the run.
                print(f"Skipping entry without 'epoch': {snippet}")
                continue
            current_epoch = parsed["epoch"]
            # A drop in the raw epoch means the counter restarted (presumably a
            # new/resumed run in the same log); carry the last raw epoch forward
            # so the combined series keeps increasing.
            if current_epoch < last_epoch:
                epoch_offset += last_epoch
            parsed["epoch"] += epoch_offset
            last_epoch = current_epoch
            records.append(parsed)
    return records


def _plot_eval_loss(df, plot_file_path, show):
    """Plot ``eval_loss`` against ``epoch``, save to *plot_file_path*, optionally show."""
    # Use an explicit figure so repeated calls don't draw onto the same axes.
    fig, ax = plt.subplots()
    ax.plot(df["epoch"], df["eval_loss"])
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Eval Loss")
    ax.set_title("Eval Loss vs Epoch")
    ax.grid()
    fig.savefig(plot_file_path)
    if show:
        plt.show()
    plt.close(fig)


def main(log_file_path, output_dir=None, *, show=True):
    """Extract ``eval_loss`` records from a log file, save them as CSV, and plot them.

    Every ``{...}`` snippet in the log containing ``{'eval_loss'`` is parsed into
    a dict. Snippets that cannot be parsed, or that lack an ``epoch`` key, are
    reported and skipped. When the raw ``epoch`` value drops below the previous
    record's, the previous raw epoch is added to a running offset.

    Args:
        log_file_path: Path of the text log to scan.
        output_dir: Directory for ``output.csv`` and ``plot.png``. Defaults to
            the directory containing ``log_file_path``. Created if missing.
        show: Whether to call ``plt.show()`` after saving the plot.
    """
    if output_dir is None:
        output_dir = os.path.dirname(log_file_path)
    if output_dir:
        # os.path.dirname returns "" for a bare filename (meaning the current dir).
        os.makedirs(output_dir, exist_ok=True)

    with open(log_file_path, "r", encoding="utf-8", errors="replace") as log_file:
        records = _collect_eval_records(log_file)

    # Save extracted data to a CSV file.
    df = pd.DataFrame(records)
    csv_file_path = os.path.join(output_dir, "output.csv")
    df.to_csv(csv_file_path, index=False)

    if df.empty:
        # With no records the frame has no "epoch"/"eval_loss" columns to plot.
        print(f"No eval_loss entries found in {log_file_path}; skipping plot.")
        return

    plot_file_path = os.path.join(output_dir, "plot.png")
    _plot_eval_loss(df, plot_file_path, show)
if __name__ == "__main__":
    # Usage: parse_log_to_get_eval_loss.py LOG_FILE [OUTPUT_DIR]
    if len(sys.argv) < 2:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit(f"Usage: {sys.argv[0]} LOG_FILE [OUTPUT_DIR]")
    # Passing None is equivalent to omitting output_dir (main's default).
    main(sys.argv[1], sys.argv[2] if len(sys.argv) > 2 else None)