Petr Tsvetkov committed on
Commit
6676c5a
β€’
1 Parent(s): 2d03034

Generate a dataset for the labeling app

Browse files
api_wrappers/grazie_wrapper.py CHANGED
@@ -10,7 +10,7 @@ import config
10
  client = GrazieApiGatewayClient(
11
  grazie_agent=GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev"),
12
  url=GrazieApiGatewayUrls.STAGING,
13
- auth_type=AuthType.SERVICE,
14
  grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
15
  )
16
 
 
10
  client = GrazieApiGatewayClient(
11
  grazie_agent=GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev"),
12
  url=GrazieApiGatewayUrls.STAGING,
13
+ auth_type=AuthType.USER,
14
  grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
15
  )
16
 
api_wrappers/hf_data_loader.py CHANGED
@@ -3,14 +3,14 @@ from datasets import load_dataset
3
  import config
4
 
5
 
6
- def load_raw_rewriting_dataset_as_pandas():
7
  return load_dataset(config.HF_RAW_DATASET_NAME,
8
  split=config.HF_RAW_DATASET_SPLIT,
9
  token=config.HF_TOKEN,
10
  cache_dir=config.CACHE_DIR).to_pandas()
11
 
12
 
13
- def load_full_commit_dataset_as_pandas():
14
  return load_dataset(path=config.HF_FULL_COMMITS_DATASET_NAME,
15
  name=config.HF_FULL_COMMITS_DATASET_SUBNAME,
16
  split=config.HF_FULL_COMMITS_DATASET_SPLIT,
@@ -18,19 +18,35 @@ def load_full_commit_dataset_as_pandas():
18
  columns={'message': 'reference'})
19
 
20
 
21
- def load_processed_rewriting_dataset_as_pandas():
22
- manual_rewriting = load_raw_rewriting_dataset_as_pandas()[
23
  ["hash", "repo", "commit_msg_start", "commit_msg_end", "session"]]
24
  manual_rewriting.set_index(["hash", "repo"], inplace=True)
25
 
26
- mods_dataset = load_full_commit_dataset_as_pandas()[["hash", "repo", "mods"]]
27
  mods_dataset.set_index(["hash", "repo"], inplace=True)
28
 
29
  return manual_rewriting.join(other=mods_dataset, how='left').reset_index()
30
 
31
 
32
- def load_synthetic_dataset_as_pandas():
33
  return load_dataset(config.HF_SYNTHETIC_DATASET_NAME,
34
  split=config.HF_SYNTHETIC_DATASET_SPLIT,
35
  token=config.HF_TOKEN,
36
  cache_dir=config.CACHE_DIR).to_pandas()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import config
4
 
5
 
6
+ def load_raw_rewriting_as_pandas():
7
  return load_dataset(config.HF_RAW_DATASET_NAME,
8
  split=config.HF_RAW_DATASET_SPLIT,
9
  token=config.HF_TOKEN,
10
  cache_dir=config.CACHE_DIR).to_pandas()
11
 
12
 
13
+ def load_full_commit_as_pandas():
14
  return load_dataset(path=config.HF_FULL_COMMITS_DATASET_NAME,
15
  name=config.HF_FULL_COMMITS_DATASET_SUBNAME,
16
  split=config.HF_FULL_COMMITS_DATASET_SPLIT,
 
18
  columns={'message': 'reference'})
19
 
20
 
21
+ def load_processed_rewriting_as_pandas():
22
+ manual_rewriting = load_raw_rewriting_as_pandas()[
23
  ["hash", "repo", "commit_msg_start", "commit_msg_end", "session"]]
24
  manual_rewriting.set_index(["hash", "repo"], inplace=True)
25
 
26
+ mods_dataset = load_full_commit_as_pandas()[["hash", "repo", "mods"]]
27
  mods_dataset.set_index(["hash", "repo"], inplace=True)
28
 
29
  return manual_rewriting.join(other=mods_dataset, how='left').reset_index()
30
 
31
 
32
+ def load_synthetic_as_pandas():
33
  return load_dataset(config.HF_SYNTHETIC_DATASET_NAME,
34
  split=config.HF_SYNTHETIC_DATASET_SPLIT,
35
  token=config.HF_TOKEN,
36
  cache_dir=config.CACHE_DIR).to_pandas()
37
+
38
+
39
+ def load_full_commit_with_predictions_as_pandas():
40
+ full_dataset = load_full_commit_as_pandas()
41
+ predictions_dataset = load_dataset(config.HF_PREDICTIONS_DATASET_NAME,
42
+ config.HF_PREDICTIONS_DATASET_SUBNAME,
43
+ split=config.HF_PREDICTIONS_DATASET_SPLIT,
44
+ cache_dir=config.CACHE_DIR
45
+ ).to_pandas().sample(frac=1, random_state=config.RANDOM_STATE
46
+ ).set_index(['hash', 'repo'])[["prediction"]]
47
+
48
+ predictions_dataset = predictions_dataset[~predictions_dataset.index.duplicated(keep='first')]
49
+
50
+ dataset = full_dataset.join(other=predictions_dataset, on=('hash', 'repo'))
51
+
52
+ return dataset.reset_index()
config.py CHANGED
@@ -15,6 +15,10 @@ HF_FULL_COMMITS_DATASET_NAME = "JetBrains-Research/lca-commit-message-generation
15
  HF_FULL_COMMITS_DATASET_SUBNAME = "commitchronicle-py-long"
16
  HF_FULL_COMMITS_DATASET_SPLIT = "test"
17
 
 
 
 
 
18
  HF_SYNTHETIC_DATASET_NAME = "petrtsv-jb/synthetic-commit-msg-rewriting"
19
  HF_SYNTHETIC_DATASET_SPLIT = 'train'
20
 
@@ -24,8 +28,8 @@ CACHE_DIR.mkdir(exist_ok=True)
24
  OUTPUT_DIR = Path("output")
25
  OUTPUT_DIR.mkdir(exist_ok=True)
26
 
27
-
28
  END_TO_START_ARTIFACT = OUTPUT_DIR / "end_to_start.csv"
29
  START_TO_END_ARTIFACT = OUTPUT_DIR / "start_to_end.csv"
30
  SYNTHETIC_DATASET_ARTIFACT = OUTPUT_DIR / "synthetic.csv"
31
  METRICS_CORRELATIONS_ARTIFACT = OUTPUT_DIR / "metrics_correlations.csv"
 
 
15
  HF_FULL_COMMITS_DATASET_SUBNAME = "commitchronicle-py-long"
16
  HF_FULL_COMMITS_DATASET_SPLIT = "test"
17
 
18
+ HF_PREDICTIONS_DATASET_NAME = "JetBrains-Research/lca-results"
19
+ HF_PREDICTIONS_DATASET_SUBNAME = "cmg_gpt_4_0613"
20
+ HF_PREDICTIONS_DATASET_SPLIT = "test"
21
+
22
  HF_SYNTHETIC_DATASET_NAME = "petrtsv-jb/synthetic-commit-msg-rewriting"
23
  HF_SYNTHETIC_DATASET_SPLIT = 'train'
24
 
 
28
  OUTPUT_DIR = Path("output")
29
  OUTPUT_DIR.mkdir(exist_ok=True)
30
 
 
31
  END_TO_START_ARTIFACT = OUTPUT_DIR / "end_to_start.csv"
32
  START_TO_END_ARTIFACT = OUTPUT_DIR / "start_to_end.csv"
33
  SYNTHETIC_DATASET_ARTIFACT = OUTPUT_DIR / "synthetic.csv"
34
  METRICS_CORRELATIONS_ARTIFACT = OUTPUT_DIR / "metrics_correlations.csv"
35
+ DATA_FOR_LABELING_ARTIFACT = OUTPUT_DIR / "data_for_labeling.csv"
generate_annotated_diffs.py CHANGED
@@ -26,14 +26,14 @@ def annotated_diff_for_row(row):
26
 
27
 
28
  def manual_data_with_annotated_diffs():
29
- df = hf_data_loader.load_raw_rewriting_dataset_as_pandas()
30
  annotated = df.apply(annotated_diff_for_row, axis=1)
31
  df['annotated_diff'] = annotated
32
  return df
33
 
34
 
35
  def synthetic_data_with_annotated_diffs():
36
- df = hf_data_loader.load_synthetic_dataset_as_pandas()
37
  annotated = df.apply(annotated_diff_for_row, axis=1)
38
  df['annotated_diff'] = annotated
39
  return df
 
26
 
27
 
28
  def manual_data_with_annotated_diffs():
29
+ df = hf_data_loader.load_raw_rewriting_as_pandas()
30
  annotated = df.apply(annotated_diff_for_row, axis=1)
31
  df['annotated_diff'] = annotated
32
  return df
33
 
34
 
35
  def synthetic_data_with_annotated_diffs():
36
+ df = hf_data_loader.load_synthetic_as_pandas()
37
  annotated = df.apply(annotated_diff_for_row, axis=1)
38
  df['annotated_diff'] = annotated
39
  return df
generate_synthetic_dataset.py CHANGED
@@ -4,7 +4,7 @@ from generation_steps import synthetic_end_to_start, synthetic_start_to_end, met
4
 
5
 
6
  def run():
7
- df = hf_data_loader.load_processed_rewriting_dataset_as_pandas()
8
 
9
  df = synthetic_end_to_start.transform(df)
10
  df = synthetic_start_to_end.transform(df)
 
4
 
5
 
6
  def run():
7
+ df = hf_data_loader.load_processed_rewriting_as_pandas()
8
 
9
  df = synthetic_end_to_start.transform(df)
10
  df = synthetic_start_to_end.transform(df)
generation_steps/examples.py CHANGED
@@ -36,7 +36,7 @@ END OF THE IMPROVED COMMIT MESSAGE
36
  END OF THE EXAMPLE"""
37
 
38
 
39
- manual_df = hf_data_loader.load_raw_rewriting_dataset_as_pandas()[['commit_msg_start', 'commit_msg_end']]
40
  manual_df = manual_df.sample(n=N_EXAMPLES, random_state=config.RANDOM_STATE)
41
 
42
 
 
36
  END OF THE EXAMPLE"""
37
 
38
 
39
+ manual_df = hf_data_loader.load_raw_rewriting_as_pandas()[['commit_msg_start', 'commit_msg_end']]
40
  manual_df = manual_df.sample(n=N_EXAMPLES, random_state=config.RANDOM_STATE)
41
 
42
 
generation_steps/for_labeling.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from tqdm import tqdm
4
+
5
+ import config
6
+ from api_wrappers import hf_data_loader
7
+ from generation_steps import synthetic_start_to_end
8
+
9
+
10
+ def transform(df):
11
+ print(f"Generating data for labeling:")
12
+ synthetic_start_to_end.print_config()
13
+ tqdm.pandas()
14
+
15
+ manual_df = hf_data_loader.load_raw_rewriting_as_pandas()
16
+
17
+ manual_df = manual_df.sample(frac=1, random_state=config.RANDOM_STATE
18
+ ).set_index(['hash', 'repo'])[["commit_msg_start", "commit_msg_end"]]
19
+
20
+ manual_df = manual_df[~manual_df.index.duplicated(keep='first')]
21
+
22
+ def get_is_manually_rewritten(row):
23
+ commit_id = (row['hash'], row['repo'])
24
+ return commit_id in manual_df.index
25
+
26
+ result = df
27
+ result['manual_sample'] = result.progress_apply(get_is_manually_rewritten, axis=1)
28
+
29
+ def get_prediction_message(row):
30
+ commit_id = (row['hash'], row['repo'])
31
+ if row['manual_sample']:
32
+ return manual_df.loc[commit_id]['commit_msg_start']
33
+ return row['prediction']
34
+
35
+ def get_enhanced_message(row):
36
+ commit_id = (row['hash'], row['repo'])
37
+ if row['manual_sample']:
38
+ return manual_df.loc[commit_id]['commit_msg_end']
39
+ return synthetic_start_to_end.generate_end_msg(start_msg=row["prediction"],
40
+ diff=row["mods"])
41
+
42
+ result['enhanced'] = result.progress_apply(get_enhanced_message, axis=1)
43
+ result['prediction'] = result.progress_apply(get_prediction_message, axis=1)
44
+ result['mods'] = result['mods'].progress_apply(json.dumps)
45
+
46
+ result.to_csv(config.DATA_FOR_LABELING_ARTIFACT)
47
+ print("Done")
48
+ return result
49
+
50
+
51
+ def main():
52
+ synthetic_start_to_end.GENERATION_ATTEMPTS = 3
53
+ df = hf_data_loader.load_full_commit_with_predictions_as_pandas()
54
+ transform(df)
55
+
56
+
57
+ if __name__ == '__main__':
58
+ main()
generation_steps/metrics_analysis.py CHANGED
@@ -77,7 +77,7 @@ METRICS = {
77
 
78
 
79
  def attach_references(df):
80
- reference_df = hf_data_loader.load_full_commit_dataset_as_pandas().set_index(["hash", "repo"])[["reference"]]
81
  df = df.set_index(["hash", "repo"])
82
  return df.join(other=reference_df, how="left").reset_index()
83
 
 
77
 
78
 
79
  def attach_references(df):
80
+ reference_df = hf_data_loader.load_full_commit_as_pandas().set_index(["hash", "repo"])[["reference"]]
81
  df = df.set_index(["hash", "repo"])
82
  return df.join(other=reference_df, how="left").reset_index()
83
 
generation_steps/synthetic_end_to_start.py CHANGED
@@ -98,7 +98,7 @@ def transform(df):
98
 
99
 
100
  def main():
101
- df = hf_data_loader.load_processed_rewriting_dataset_as_pandas()
102
  transform(df)
103
 
104
 
 
98
 
99
 
100
  def main():
101
+ df = hf_data_loader.load_processed_rewriting_as_pandas()
102
  transform(df)
103
 
104
 
generation_steps/synthetic_start_to_end.py CHANGED
@@ -12,7 +12,7 @@ REL_DELETIONS_THRESHOLD = 0.75
12
  GENERATION_ATTEMPTS = 5
13
 
14
 
15
- def build_prompt(reference, diff):
16
  return f"""A LLM generated a commit message for the following source code changes:
17
  START OF THE SOURCE CODE CHANGES
18
  {diff}
@@ -20,7 +20,7 @@ END OF THE SOURCE CODE CHANGES
20
 
21
  Here is the message the LLM generated:
22
  START OF THE COMMIT MESSAGE
23
- {reference}
24
  END OF THE COMMIT MESSAGE
25
 
26
  This generated message is not perfect. Your task is to rewrite and improve it.
@@ -40,20 +40,20 @@ token "OUTPUT".
40
  OUTPUT"""
41
 
42
 
43
- def generate_start_msg(end_msg, diff):
44
- prompt = build_prompt(reference=end_msg, diff=diff)
45
  results = []
46
 
47
  for i in range(GENERATION_ATTEMPTS):
48
- start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
49
 
50
- stats = statistics.get_statistics(start_msg=start_msg_pred, end_msg=end_msg,
51
- annotated_msg=generate_annotated_diffs.get_annotated_diff(start_msg_pred,
52
- end_msg))
53
  if stats["deletions"] < REL_DELETIONS_THRESHOLD:
54
- return start_msg_pred
55
  else:
56
- results.append((stats["deletions"], start_msg_pred))
57
 
58
  results.sort()
59
  return results[0][1]
@@ -62,13 +62,17 @@ def generate_start_msg(end_msg, diff):
62
  COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]
63
 
64
 
65
- def transform(df):
66
- print(f"Start -> send synthesis:")
67
  print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
68
  print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
69
  print(f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}")
70
  print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")
71
 
 
 
 
 
 
72
  df['start_to_end'] = False
73
 
74
  generated_data = {
@@ -80,8 +84,8 @@ def transform(df):
80
 
81
  for _, row in tqdm(df.iterrows(), total=len(df)):
82
  for i in range(GENERATION_MULTIPLIER):
83
- commit_msg_end_pred = generate_start_msg(end_msg=row["commit_msg_start"],
84
- diff=row["mods"])
85
 
86
  generated_data["commit_msg_end"].append(commit_msg_end_pred)
87
  for col in COLS_TO_KEEP:
 
12
  GENERATION_ATTEMPTS = 5
13
 
14
 
15
+ def build_prompt(prediction, diff):
16
  return f"""A LLM generated a commit message for the following source code changes:
17
  START OF THE SOURCE CODE CHANGES
18
  {diff}
 
20
 
21
  Here is the message the LLM generated:
22
  START OF THE COMMIT MESSAGE
23
+ {prediction}
24
  END OF THE COMMIT MESSAGE
25
 
26
  This generated message is not perfect. Your task is to rewrite and improve it.
 
40
  OUTPUT"""
41
 
42
 
43
+ def generate_end_msg(start_msg, diff):
44
+ prompt = build_prompt(prediction=start_msg, diff=diff)
45
  results = []
46
 
47
  for i in range(GENERATION_ATTEMPTS):
48
+ end_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
49
 
50
+ stats = statistics.get_statistics(start_msg=start_msg, end_msg=end_msg_pred,
51
+ annotated_msg=generate_annotated_diffs.get_annotated_diff(start_msg,
52
+ end_msg_pred))
53
  if stats["deletions"] < REL_DELETIONS_THRESHOLD:
54
+ return end_msg_pred
55
  else:
56
+ results.append((stats["deletions"], end_msg_pred))
57
 
58
  results.sort()
59
  return results[0][1]
 
62
  COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]
63
 
64
 
65
+ def print_config():
 
66
  print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
67
  print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
68
  print(f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}")
69
  print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")
70
 
71
+
72
+ def transform(df):
73
+ print(f"Start -> send synthesis:")
74
+ print_config()
75
+
76
  df['start_to_end'] = False
77
 
78
  generated_data = {
 
84
 
85
  for _, row in tqdm(df.iterrows(), total=len(df)):
86
  for i in range(GENERATION_MULTIPLIER):
87
+ commit_msg_end_pred = generate_end_msg(start_msg=row["commit_msg_start"],
88
+ diff=row["mods"])
89
 
90
  generated_data["commit_msg_end"].append(commit_msg_end_pred)
91
  for col in COLS_TO_KEEP: