|
import re |
|
import json |
|
import sys |
|
import os |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def extract_dicts(line):
    """Return every brace-delimited substring found in *line*.

    Each match starts at a ``{`` and runs up to the first following ``}``,
    braces included. Nested braces are therefore not balanced: for
    ``"{'a': {'b': 1}}"`` the single match is ``"{'a': {'b': 1}"``.

    Args:
        line: A single line of text (e.g. one line of a training log).

    Returns:
        A list of the matched substrings in order of appearance; empty when
        the line contains no ``{...}`` group.
    """
    return re.findall(r"{[^}]*}", line)
|
|
|
|
|
def _collect_eval_records(log_lines):
    """Parse the eval-metric dicts out of *log_lines* and offset their epochs.

    Only brace groups containing the text ``{'eval_loss'`` are considered.
    Each one is converted from a Python-style dict repr to JSON by swapping
    single quotes for double quotes, then parsed.

    Whenever an entry's epoch is lower than the previous entry's epoch
    (presumably because training was restarted and the counter reset), the
    previous epoch is added to a running offset so that the stored "epoch"
    values keep increasing across the whole log.

    Args:
        log_lines: Any iterable of text lines.

    Returns:
        A list of parsed dicts, in log order, with "epoch" adjusted by the
        running offset. Entries that fail to parse or lack an "epoch" key
        are reported on stdout and skipped.
    """
    eval_marker = "{'eval_loss'"
    records = []
    last_epoch = 0
    epoch_offset = 0

    for line in log_lines:
        for candidate in extract_dicts(line):
            if eval_marker not in candidate:
                continue

            # The log contains Python dict reprs (single-quoted); JSON needs
            # double quotes.
            candidate = candidate.replace("'", '"')
            try:
                record = json.loads(candidate)
            except json.JSONDecodeError:
                print(f"Could not parse line: {candidate}")
                continue

            if "epoch" not in record:
                # Without an epoch the entry can be neither offset nor
                # plotted; skip it rather than abort the whole run.
                print(f"Skipping eval entry without 'epoch': {candidate}")
                continue

            current_epoch = record["epoch"]
            if current_epoch < last_epoch:
                # Epoch counter went backwards: carry the previous value
                # forward so the adjusted epochs stay monotonic.
                epoch_offset += last_epoch

            record["epoch"] = current_epoch + epoch_offset
            last_epoch = current_epoch
            records.append(record)

    return records


def main(log_file_path, output_dir=None):
    """Extract eval-loss entries from a log, save them as CSV, and plot them.

    Writes ``output.csv`` (all parsed eval entries) and ``plot.png``
    (eval loss against epoch) into *output_dir*, then displays the plot.

    Args:
        log_file_path: Path to the log file to read.
        output_dir: Directory for the generated files. Defaults to the
            directory containing the log file. Created if it does not exist.

    Returns:
        None. If the log contains no usable eval entries, a message is
        printed and no files are written.
    """
    if output_dir is None:
        output_dir = os.path.dirname(log_file_path)

    # Stream the file line by line instead of loading it all into memory.
    with open(log_file_path, "r") as log_file:
        records = _collect_eval_records(log_file)

    if not records:
        # An empty DataFrame has no "epoch"/"eval_loss" columns, so plotting
        # would raise KeyError; bail out with a clear message instead.
        print(f"No eval_loss entries found in {log_file_path}; nothing to write.")
        return

    # An empty string means "current directory" and must not be passed to
    # os.makedirs.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    df = pd.DataFrame(records)
    csv_file_path = os.path.join(output_dir, "output.csv")
    df.to_csv(csv_file_path, index=False)

    plt.plot(df["epoch"], df["eval_loss"])
    plt.xlabel("Epoch")
    plt.ylabel("Eval Loss")
    plt.title("Eval Loss vs Epoch")
    plt.grid()

    plot_file_path = os.path.join(output_dir, "plot.png")
    # Save before show(): showing first can leave an empty figure to save.
    plt.savefig(plot_file_path)
    plt.show()
|
|
|
|
|
if __name__ == "__main__":
    # Command line: <script> LOG_FILE [OUTPUT_DIR]
    if len(sys.argv) < 2:
        # Exit with a usage hint instead of an IndexError traceback.
        sys.exit(f"Usage: {sys.argv[0]} LOG_FILE [OUTPUT_DIR]")

    log_file_path = sys.argv[1]
    # Passing None is the same as omitting the argument: main() then falls
    # back to the log file's own directory.
    output_dir = sys.argv[2] if len(sys.argv) > 2 else None
    main(log_file_path, output_dir)
|
|