Petr Tsvetkov committed on
Commit
f1b08a8
β€’
1 Parent(s): 9d943c1

Compute & compare metrics

Browse files
generate_synthetic_dataset.py CHANGED
@@ -1,26 +1,14 @@
1
  import config
2
  from api_wrappers import hf_data_loader
3
- from generation_steps import synthetic_end_to_start, examples, synthetic_start_to_end
4
 
5
 
6
  def run():
7
  df = hf_data_loader.load_processed_rewriting_dataset_as_pandas()
8
- print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
9
- print()
10
 
11
- print(f"End -> start synthesis:")
12
- print(f"GENERATION_MULTIPLIER = {synthetic_end_to_start.GENERATION_MULTIPLIER}")
13
- print(f"REL_INSERTIONS_THRESHOLD = {synthetic_end_to_start.REL_INSERTIONS_THRESHOLD}")
14
- print(f"GENERATION_ATTEMPTS = {synthetic_end_to_start.GENERATION_ATTEMPTS}")
15
  df = synthetic_end_to_start.transform(df)
16
- print("Done")
17
-
18
- print(f"Start -> send synthesis:")
19
- print(f"GENERATION_MULTIPLIER = {synthetic_start_to_end.GENERATION_MULTIPLIER}")
20
- print(f"REL_DELETIONS_THRESHOLD = {synthetic_start_to_end.REL_DELETIONS_THRESHOLD}")
21
- print(f"GENERATION_ATTEMPTS = {synthetic_start_to_end.GENERATION_ATTEMPTS}")
22
  df = synthetic_start_to_end.transform(df)
23
- print("Done")
24
 
25
  df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
26
 
 
1
  import config
2
  from api_wrappers import hf_data_loader
3
+ from generation_steps import synthetic_end_to_start, synthetic_start_to_end, metrics_analysis
4
 
5
 
6
def run():
    """Build the synthetic dataset end to end.

    Loads the processed rewriting dataset, runs both synthesis directions,
    attaches evaluation metrics, and writes the result to the artifact CSV.
    """
    df = hf_data_loader.load_processed_rewriting_dataset_as_pandas()

    # Each step consumes and returns the accumulated DataFrame.
    for step in (synthetic_end_to_start, synthetic_start_to_end, metrics_analysis):
        df = step.transform(df)

    df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
14
 
generation_steps/metrics_analysis.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import pandas as pd
3
+ from tqdm import tqdm
4
+
5
+ import config
6
+ from api_wrappers import hf_data_loader
7
+
8
# BLEU metric, loaded once at import time and shared by all calls.
BLEU = evaluate.load('bleu', cache_dir=config.CACHE_DIR)


def bleu_fn(pred, ref):
    """Return the BLEU score of a single prediction/reference pair."""
    scores = BLEU.compute(predictions=[pred], references=[ref])
    return scores["bleu"]
13
+
14
+
15
# METEOR metric, loaded once at import time and shared by all calls.
METEOR = evaluate.load('meteor', cache_dir=config.CACHE_DIR)


def meteor_fn(pred, ref):
    """Return the METEOR score of a single prediction/reference pair."""
    scores = METEOR.compute(predictions=[pred], references=[ref])
    return scores["meteor"]
20
+
21
+
22
# ROUGE metric, loaded once at import time and shared by all calls.
ROUGE = evaluate.load('rouge', cache_dir=config.CACHE_DIR)


def rouge1_fn(pred, ref):
    """Return the ROUGE-1 score of a single prediction/reference pair."""
    scores = ROUGE.compute(predictions=[pred], references=[ref])
    return scores["rouge1"]


def rouge2_fn(pred, ref):
    """Return the ROUGE-2 score of a single prediction/reference pair."""
    scores = ROUGE.compute(predictions=[pred], references=[ref])
    return scores["rouge2"]
31
+
32
+
33
# BERTScore metric, loaded once at import time and shared by all calls.
BERTSCORE = evaluate.load('bertscore', cache_dir=config.CACHE_DIR)


def bertscore_fn(pred, ref):
    """Return the BERTScore F1 of a single prediction/reference pair."""
    scores = BERTSCORE.compute(predictions=[pred], references=[ref],
                               model_type="distilbert-base-uncased")
    # compute() returns per-example lists; one pair in, so take element 0.
    return scores["f1"][0]
38
+
39
+
40
# Metric name -> scoring function; each takes (pred, ref) strings and
# returns a float. compute_metrics() iterates this mapping.
METRICS = {
    "bleu": bleu_fn,
    "meteor": meteor_fn,
    "rouge1": rouge1_fn,
    "rouge2": rouge2_fn,
    "bertscore": bertscore_fn
}
47
+
48
+
49
def attach_references(df):
    """Left-join the "reference" message onto df, keyed by (hash, repo).

    Rows of df with no match in the full commit dataset get NaN references.
    """
    full_commits = hf_data_loader.load_full_commit_dataset_as_pandas()
    refs = full_commits.set_index(["hash", "repo"])[["reference"]]
    joined = df.set_index(["hash", "repo"]).join(other=refs, how="left")
    return joined.reset_index()
53
+
54
+
55
def compute_metrics(df):
    """Add per-row metric scores and per-metric correlations to df.

    For every metric in METRICS, four columns are written:
      * "<name>_related":     commit_msg_start scored against commit_msg_end
      * "<name>_independent": commit_msg_start scored against "reference"
      * "<name>_pearson" / "<name>_spearman": the scalar correlation between
        the two score columns, broadcast into every row.

    Returns the augmented DataFrame.
    """
    tqdm.pandas()

    for metric, metric_fn in METRICS.items():
        print(f"Computing {metric}")

        def score_rows(ref_col):
            # Apply the current metric to every row, with a progress bar.
            return df.progress_apply(
                lambda row: metric_fn(row["commit_msg_start"], row[ref_col]),
                axis=1,
            )

        related = score_rows("commit_msg_end")
        independent = score_rows("reference")
        df[f"{metric}_related"] = related
        df[f"{metric}_independent"] = independent

        # Series.corr returns a scalar; assignment broadcasts it column-wide.
        df[f"{metric}_pearson"] = related.corr(independent, method="pearson")
        df[f"{metric}_spearman"] = related.corr(independent, method="spearman")

    return df
83
+
84
+
85
def transform(df):
    """Attach references to df, score it with all metrics, and return it."""
    print("Computing metrics")

    with_refs = attach_references(df)
    scored = compute_metrics(with_refs)

    print("Done")
    return scored
93
+
94
+
95
def main():
    """Re-score the saved synthetic dataset artifact in place."""
    frame = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
    frame = transform(frame)
    frame.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
99
+
100
+
101
# Allow running the metrics step standalone, outside the full pipeline.
if __name__ == '__main__':
    main()
generation_steps/synthetic_end_to_start.py CHANGED
@@ -1,12 +1,13 @@
1
  import pandas as pd
2
  from tqdm import tqdm
3
 
 
4
  import generate_annotated_diffs
5
  import statistics
6
  from api_wrappers import grazie_wrapper
7
  from generation_steps import examples
8
 
9
- GENERATION_MULTIPLIER = 1
10
  REL_INSERTIONS_THRESHOLD = 0.5
11
  GENERATION_ATTEMPTS = 5
12
 
@@ -62,6 +63,12 @@ COLS_TO_KEEP = ["hash", "repo", "commit_msg_end", "mods", "session"]
62
 
63
 
64
  def transform(df):
 
 
 
 
 
 
65
  df['end_to_start'] = False
66
 
67
  generated_data = {
@@ -83,4 +90,17 @@ def transform(df):
83
  generated_df = pd.DataFrame.from_dict(generated_data)
84
  generated_df['end_to_start'] = True
85
 
86
- return pd.concat([df, generated_df], ignore_index=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  from tqdm import tqdm
3
 
4
+ import config
5
  import generate_annotated_diffs
6
  import statistics
7
  from api_wrappers import grazie_wrapper
8
  from generation_steps import examples
9
 
10
+ GENERATION_MULTIPLIER = 3
11
  REL_INSERTIONS_THRESHOLD = 0.5
12
  GENERATION_ATTEMPTS = 5
13
 
 
63
 
64
 
65
  def transform(df):
66
+ print(f"End -> start synthesis:")
67
+ print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
68
+ print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
69
+ print(f"REL_INSERTIONS_THRESHOLD = {REL_INSERTIONS_THRESHOLD}")
70
+ print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")
71
+
72
  df['end_to_start'] = False
73
 
74
  generated_data = {
 
90
  generated_df = pd.DataFrame.from_dict(generated_data)
91
  generated_df['end_to_start'] = True
92
 
93
+ result = pd.concat([df, generated_df], ignore_index=True)
94
+
95
+ print("Done")
96
+ return result
97
+
98
+
99
def main():
    """Run end-to-start synthesis over the saved dataset artifact in place."""
    frame = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
    frame = transform(frame)
    frame.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
103
+
104
+
105
# Allow running the end-to-start step standalone, outside the full pipeline.
if __name__ == '__main__':
    main()
generation_steps/synthetic_start_to_end.py CHANGED
@@ -1,12 +1,13 @@
1
  import pandas as pd
2
  from tqdm import tqdm
3
 
 
4
  import generate_annotated_diffs
5
  import statistics
6
  from api_wrappers import grazie_wrapper
7
  from generation_steps import examples
8
 
9
- GENERATION_MULTIPLIER = 1
10
  REL_DELETIONS_THRESHOLD = 0.75
11
  GENERATION_ATTEMPTS = 5
12
 
@@ -62,6 +63,12 @@ COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_s
62
 
63
 
64
  def transform(df):
 
 
 
 
 
 
65
  df['start_to_end'] = False
66
 
67
  generated_data = {
@@ -83,4 +90,17 @@ def transform(df):
83
  generated_df = pd.DataFrame.from_dict(generated_data)
84
  generated_df['start_to_end'] = True
85
 
86
- return pd.concat([df, generated_df], ignore_index=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  from tqdm import tqdm
3
 
4
+ import config
5
  import generate_annotated_diffs
6
  import statistics
7
  from api_wrappers import grazie_wrapper
8
  from generation_steps import examples
9
 
10
+ GENERATION_MULTIPLIER = 3
11
  REL_DELETIONS_THRESHOLD = 0.75
12
  GENERATION_ATTEMPTS = 5
13
 
 
63
 
64
 
65
  def transform(df):
66
+ print(f"Start -> end synthesis:")
67
+ print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
68
+ print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
69
+ print(f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}")
70
+ print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")
71
+
72
  df['start_to_end'] = False
73
 
74
  generated_data = {
 
90
  generated_df = pd.DataFrame.from_dict(generated_data)
91
  generated_df['start_to_end'] = True
92
 
93
+ result = pd.concat([df, generated_df], ignore_index=True)
94
+
95
+ print("Done")
96
+ return result
97
+
98
+
99
def main():
    """Run start-to-end synthesis over the saved dataset artifact in place."""
    frame = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
    frame = transform(frame)
    frame.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
103
+
104
+
105
# Allow running the start-to-end step standalone, outside the full pipeline.
if __name__ == '__main__':
    main()