Petr Tsvetkov commited on
Commit
e027012
1 Parent(s): f1b08a8

Add checkpoints

Browse files
config.py CHANGED
@@ -24,4 +24,8 @@ CACHE_DIR.mkdir(exist_ok=True)
24
  OUTPUT_DIR = Path("output")
25
  OUTPUT_DIR.mkdir(exist_ok=True)
26
 
 
 
 
27
  SYNTHETIC_DATASET_ARTIFACT = OUTPUT_DIR / "synthetic.csv"
 
 
24
  OUTPUT_DIR = Path("output")
25
  OUTPUT_DIR.mkdir(exist_ok=True)
26
 
27
+
28
+ END_TO_START_ARTIFACT = OUTPUT_DIR / "end_to_start.csv"
29
+ START_TO_END_ARTIFACT = OUTPUT_DIR / "start_to_end.csv"
30
  SYNTHETIC_DATASET_ARTIFACT = OUTPUT_DIR / "synthetic.csv"
31
+ METRICS_CORRELATIONS_ARTIFACT = OUTPUT_DIR / "metrics_correlations.csv"
custom_metrics/__init__.py ADDED
File without changes
custom_metrics/gpt_eval.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ from api_wrappers import grazie_wrapper
4
+
5
+
6
+ def build_prompt(prediction, reference):
7
+ return f"""Your task is to rate the quality of the generated commit message using the scale from 1 to 5.
8
+
9
+ A good commit message has to be concise.
10
+ Assign lower scores for the commit messages that are too verbose for a commit message.
11
+
12
+ The generated commit message you have to evaluate:
13
+ START OF THE GENERATED COMMIT MESSAGE
14
+ {prediction}
15
+ END OF THE GENERATED COMMIT MESSAGE
16
+
17
+ Here is an example of an ideal reference commit message for the same commit:
18
+ START OF THE REFERENCE COMMIT MESSAGE
19
+ {reference}
20
+ END OF THE REFERENCE COMMIT MESSAGE
21
+
22
+ All the information in the reference commit message is true.
23
+
24
+ Print only one integer number after the token "OUTPUT" - the rating of the generated commit message.
25
+ Do not print anything that is not an integer.
26
+
27
+ OUTPUT
28
+ """
29
+
30
+
31
+ N_RETRIES = 3
32
+
33
+
34
+ def compute(prediction, reference):
35
+ prompt = build_prompt(prediction, reference)
36
+ outputs = []
37
+
38
+ for i in range(N_RETRIES):
39
+ try:
40
+ output = grazie_wrapper.generate_for_prompt(prompt).strip()[-1]
41
+ outputs.append(output)
42
+ return int(output)
43
+ except ValueError:
44
+ continue
45
+
46
+ raise RuntimeError(f"GPT4 cannot generate a number. Its outputs were: {str(outputs)}")
generation_steps/metrics_analysis.py CHANGED
@@ -1,9 +1,13 @@
 
 
 
1
  import evaluate
2
  import pandas as pd
3
  from tqdm import tqdm
4
 
5
  import config
6
  from api_wrappers import hf_data_loader
 
7
 
8
  BLEU = evaluate.load('bleu', cache_dir=config.CACHE_DIR)
9
 
@@ -37,12 +41,17 @@ def bertscore_fn(pred, ref):
37
  return BERTSCORE.compute(predictions=[pred], references=[ref], model_type="distilbert-base-uncased")["f1"][0]
38
 
39
 
 
 
 
 
40
  METRICS = {
 
41
  "bleu": bleu_fn,
42
  "meteor": meteor_fn,
43
  "rouge1": rouge1_fn,
44
  "rouge2": rouge2_fn,
45
- "bertscore": bertscore_fn
46
  }
47
 
48
 
@@ -82,20 +91,40 @@ def compute_metrics(df):
82
  return df
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def transform(df):
86
  print("Computing metrics")
87
 
88
  df = attach_references(df)
89
  df = compute_metrics(df)
90
 
 
 
 
 
 
91
  print("Done")
92
  return df
93
 
94
 
95
  def main():
96
- df = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
97
- df = transform(df)
98
- df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
99
 
100
 
101
  if __name__ == '__main__':
 
1
+ import functools
2
+ import operator
3
+
4
  import evaluate
5
  import pandas as pd
6
  from tqdm import tqdm
7
 
8
  import config
9
  from api_wrappers import hf_data_loader
10
+ from custom_metrics import gpt_eval
11
 
12
  BLEU = evaluate.load('bleu', cache_dir=config.CACHE_DIR)
13
 
 
41
  return BERTSCORE.compute(predictions=[pred], references=[ref], model_type="distilbert-base-uncased")["f1"][0]
42
 
43
 
44
+ def gptscore_fn(pred, ref):
45
+ return gpt_eval.compute(prediction=pred, reference=ref)
46
+
47
+
48
  METRICS = {
49
+ "gptscore": gptscore_fn,
50
  "bleu": bleu_fn,
51
  "meteor": meteor_fn,
52
  "rouge1": rouge1_fn,
53
  "rouge2": rouge2_fn,
54
+ "bertscore": bertscore_fn,
55
  }
56
 
57
 
 
91
  return df
92
 
93
 
94
+ def correlations_for_group(group):
95
+ correlations = []
96
+ for metric in METRICS:
97
+ correlations.append({
98
+ f"{metric}_pearson": group[f"{metric}_related"].corr(group[f"{metric}_independent"], method="pearson"),
99
+ f"{metric}_spearman": group[f"{metric}_related"].corr(group[f"{metric}_independent"], method="spearman")
100
+ })
101
+ return pd.Series(functools.reduce(operator.ior, correlations, {}))
102
+
103
+
104
+ def compute_correlations(df: pd.DataFrame):
105
+ grouped_df = df.groupby(by=["end_to_start", "start_to_end"])
106
+ correlations = grouped_df.apply(correlations_for_group, include_groups=False)
107
+ return correlations
108
+
109
+
110
  def transform(df):
111
  print("Computing metrics")
112
 
113
  df = attach_references(df)
114
  df = compute_metrics(df)
115
 
116
+ correlations_for_groups = compute_correlations(df)
117
+ correlations_for_groups.to_csv(config.METRICS_CORRELATIONS_ARTIFACT)
118
+
119
+ df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
120
+
121
  print("Done")
122
  return df
123
 
124
 
125
  def main():
126
+ df = pd.read_csv(config.START_TO_END_ARTIFACT, index_col=[0])
127
+ transform(df)
 
128
 
129
 
130
  if __name__ == '__main__':
generation_steps/synthetic_end_to_start.py CHANGED
@@ -4,7 +4,7 @@ from tqdm import tqdm
4
  import config
5
  import generate_annotated_diffs
6
  import statistics
7
- from api_wrappers import grazie_wrapper
8
  from generation_steps import examples
9
 
10
  GENERATION_MULTIPLIER = 3
@@ -91,15 +91,15 @@ def transform(df):
91
  generated_df['end_to_start'] = True
92
 
93
  result = pd.concat([df, generated_df], ignore_index=True)
 
94
 
95
  print("Done")
96
  return result
97
 
98
 
99
  def main():
100
- df = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
101
- df = transform(df)
102
- df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
103
 
104
 
105
  if __name__ == '__main__':
 
4
  import config
5
  import generate_annotated_diffs
6
  import statistics
7
+ from api_wrappers import grazie_wrapper, hf_data_loader
8
  from generation_steps import examples
9
 
10
  GENERATION_MULTIPLIER = 3
 
91
  generated_df['end_to_start'] = True
92
 
93
  result = pd.concat([df, generated_df], ignore_index=True)
94
+ result.to_csv(config.END_TO_START_ARTIFACT)
95
 
96
  print("Done")
97
  return result
98
 
99
 
100
  def main():
101
+ df = hf_data_loader.load_processed_rewriting_dataset_as_pandas()
102
+ transform(df)
 
103
 
104
 
105
  if __name__ == '__main__':
generation_steps/synthetic_start_to_end.py CHANGED
@@ -91,15 +91,15 @@ def transform(df):
91
  generated_df['start_to_end'] = True
92
 
93
  result = pd.concat([df, generated_df], ignore_index=True)
 
94
 
95
  print("Done")
96
  return result
97
 
98
 
99
  def main():
100
- df = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
101
- df = transform(df)
102
- df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
103
 
104
 
105
  if __name__ == '__main__':
 
91
  generated_df['start_to_end'] = True
92
 
93
  result = pd.concat([df, generated_df], ignore_index=True)
94
+ result.to_csv(config.START_TO_END_ARTIFACT)
95
 
96
  print("Done")
97
  return result
98
 
99
 
100
  def main():
101
+ df = pd.read_csv(config.END_TO_START_ARTIFACT, index_col=[0])
102
+ transform(df)
 
103
 
104
 
105
  if __name__ == '__main__':