Petr Tsvetkov commited on
Commit
a8a595d
1 Parent(s): e2a35c0

- New version of the end->start synthetics samples generation

Browse files
api_wrappers/__init__.py ADDED
File without changes
api_wrappers/grazie_wrapper.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ from grazie.api.client.chat.prompt import ChatPrompt
4
+ from grazie.api.client.endpoints import GrazieApiGatewayUrls
5
+ from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType
6
+ from grazie.api.client.profiles import LLMProfile
7
+
8
+ import config
9
+
10
+ client = GrazieApiGatewayClient(
11
+ grazie_agent=GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev"),
12
+ url=GrazieApiGatewayUrls.STAGING,
13
+ auth_type=AuthType.SERVICE,
14
+ grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
15
+ )
16
+
17
+
18
+ def generate_for_prompt(prompt):
19
+ output = None
20
+
21
+ while output is None:
22
+ try:
23
+ output = output = client.chat(
24
+ chat=ChatPrompt()
25
+ .add_system("You are a helpful assistant.")
26
+ .add_user(prompt),
27
+ profile=LLMProfile("gpt-4-1106-preview")
28
+ ).content
29
+ except:
30
+ time.sleep(config.GRAZIE_TIMEOUT_SEC)
31
+
32
+ assert output is not None
33
+
34
+ return output
hf_data_loader.py → api_wrappers/hf_data_loader.py RENAMED
@@ -18,6 +18,16 @@ def load_full_commit_dataset_as_pandas():
18
  columns={'message': 'reference'})
19
 
20
 
 
 
 
 
 
 
 
 
 
 
21
  def load_synthetic_dataset_as_pandas():
22
  return load_dataset(config.HF_SYNTHETIC_DATASET_NAME,
23
  split=config.HF_SYNTHETIC_DATASET_SPLIT,
 
18
  columns={'message': 'reference'})
19
 
20
 
21
+ def load_processed_rewriting_dataset_as_pandas():
22
+ manual_rewriting = load_raw_rewriting_dataset_as_pandas()[["hash", "repo", "commit_msg_start", "commit_msg_end"]]
23
+ manual_rewriting.set_index(["hash", "repo"], inplace=True)
24
+
25
+ mods_dataset = load_full_commit_dataset_as_pandas()[["hash", "repo", "mods"]]
26
+ mods_dataset.set_index(["hash", "repo"], inplace=True)
27
+
28
+ return manual_rewriting.join(other=mods_dataset, how='left').reset_index()
29
+
30
+
31
  def load_synthetic_dataset_as_pandas():
32
  return load_dataset(config.HF_SYNTHETIC_DATASET_NAME,
33
  split=config.HF_SYNTHETIC_DATASET_SPLIT,
config.py CHANGED
@@ -1,7 +1,10 @@
1
  import os
2
  from pathlib import Path
3
 
 
 
4
  GRAZIE_API_JWT_TOKEN = os.environ.get("GRAZIE_API_JWT_TOKEN")
 
5
 
6
  HF_TOKEN = os.environ.get('HF_TOKEN')
7
 
 
1
  import os
2
  from pathlib import Path
3
 
4
+ RANDOM_STATE = 42
5
+
6
  GRAZIE_API_JWT_TOKEN = os.environ.get("GRAZIE_API_JWT_TOKEN")
7
+ GRAZIE_TIMEOUT_SEC = 1.0
8
 
9
  HF_TOKEN = os.environ.get('HF_TOKEN')
10
 
generate_annotated_diffs.py CHANGED
@@ -1,6 +1,6 @@
1
  import diff_match_patch as dmp_module
2
 
3
- import hf_data_loader
4
 
5
 
6
  def get_annotated_diff(start_text, end_text):
@@ -19,27 +19,21 @@ def get_annotated_diff(start_text, end_text):
19
  return result
20
 
21
 
22
- def annotated_diff_for_row_manual_df(row):
23
  start = row['commit_msg_start']
24
  end = row['commit_msg_end']
25
  return get_annotated_diff(start, end)
26
 
27
 
28
- def annotated_diff_for_row_synthetic_df(row):
29
- start = row['initial_msg_pred']
30
- end = row['reference']
31
- return get_annotated_diff(start, end)
32
-
33
-
34
  def manual_data_with_annotated_diffs():
35
  df = hf_data_loader.load_raw_rewriting_dataset_as_pandas()
36
- annotated = df.apply(annotated_diff_for_row_manual_df, axis=1)
37
  df['annotated_diff'] = annotated
38
  return df
39
 
40
 
41
  def synthetic_data_with_annotated_diffs():
42
  df = hf_data_loader.load_synthetic_dataset_as_pandas()
43
- annotated = df.apply(annotated_diff_for_row_synthetic_df, axis=1)
44
  df['annotated_diff'] = annotated
45
  return df
 
1
  import diff_match_patch as dmp_module
2
 
3
+ from api_wrappers import hf_data_loader
4
 
5
 
6
  def get_annotated_diff(start_text, end_text):
 
19
  return result
20
 
21
 
22
+ def annotated_diff_for_row(row):
23
  start = row['commit_msg_start']
24
  end = row['commit_msg_end']
25
  return get_annotated_diff(start, end)
26
 
27
 
 
 
 
 
 
 
28
  def manual_data_with_annotated_diffs():
29
  df = hf_data_loader.load_raw_rewriting_dataset_as_pandas()
30
+ annotated = df.apply(annotated_diff_for_row, axis=1)
31
  df['annotated_diff'] = annotated
32
  return df
33
 
34
 
35
  def synthetic_data_with_annotated_diffs():
36
  df = hf_data_loader.load_synthetic_dataset_as_pandas()
37
+ annotated = df.apply(annotated_diff_for_row, axis=1)
38
  df['annotated_diff'] = annotated
39
  return df
generate_synthetic_dataset.py CHANGED
@@ -1,114 +1,20 @@
1
- import time
2
-
3
- from grazie.api.client.chat.prompt import ChatPrompt
4
- from grazie.api.client.endpoints import GrazieApiGatewayUrls
5
- from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType
6
- from grazie.api.client.profiles import LLMProfile
7
- from tqdm import tqdm
8
-
9
  import config
10
- import hf_data_loader
11
-
12
- client = GrazieApiGatewayClient(
13
- grazie_agent=GrazieAgent(name="commit-rewriting-summary-generation", version="dev"),
14
- url=GrazieApiGatewayUrls.STAGING,
15
- auth_type=AuthType.SERVICE,
16
- grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
17
- )
18
-
19
-
20
- def get_example_prompt(start_msg, end_msg):
21
- return f"""START OF THE EXAMPLE
22
-
23
- For following the edited message:
24
- START OF THE EDITED COMMIT MESSAGE
25
- {end_msg}
26
- END OF THE EDITED COMMIT MESSAGE
27
-
28
- You would output the following initial commit message:
29
- START OF THE INITIAL COMMIT MESSAGE
30
- {start_msg}
31
- END OF THE INITIAL COMMIT MESSAGE
32
-
33
- END OF THE EXAMPLE"""
34
-
35
-
36
- def generate_examples():
37
- manual_df = hf_data_loader.load_raw_rewriting_dataset_as_pandas()[['commit_msg_start', 'commit_msg_end']]
38
- examples = [
39
- get_example_prompt(row['commit_msg_start'], row['commit_msg_end'])
40
- for _, row in manual_df.iterrows()
41
- ]
42
-
43
- return "\n".join(examples)
44
-
45
-
46
- EXAMPLES = generate_examples()
47
-
48
-
49
- def build_prompt(reference, diff):
50
- return f"""A software developer uses a LLM to generate commit messages.
51
-
52
- They generated a commit message for the following source code changes:
53
- START OF THE SOURCE CODE CHANGES
54
- {diff}
55
- END OF THE SOURCE CODE CHANGES
56
-
57
- After generating the commit message the developer understands that it is not perfect. After making dome changes,
58
- they come up with an edited version of the message. Here is this edited message:
59
- START OF THE COMMIT MESSAGE
60
- {reference}
61
- END OF THE COMMIT MESSAGE
62
-
63
- Your task is to print the initial, LLM-generated commit message. Here are some examples of what you should output:
64
- START OF THE EXAMPLES LIST
65
- {EXAMPLES}
66
- END OF THE EXAMPLES LIST
67
-
68
- Print only the initial commit message's text after the
69
- token "OUTPUT".
70
-
71
- OUTPUT"""
72
-
73
-
74
- def generate_prompt_for_row(row):
75
- reference = row['reference']
76
- diff = row['mods']
77
- return build_prompt(reference, diff)
78
-
79
-
80
- def generate_initial_msg(prompt):
81
- commit_msg = client.chat(
82
- chat=ChatPrompt()
83
- .add_system("You are a helpful assistant.")
84
- .add_user(prompt),
85
- profile=LLMProfile("gpt-4-1106-preview")
86
- ).content
87
-
88
- return commit_msg
89
-
90
-
91
- def generate_synthetic_dataset():
92
- df = hf_data_loader.load_full_commit_dataset_as_pandas()
93
- df['initial_msg_prompt'] = df.apply(generate_prompt_for_row, axis=1)
94
- initial_messages_pred = []
95
-
96
- for i, prompt in enumerate(tqdm(df['initial_msg_ prompt'])):
97
- output = None
98
 
99
- while output is None:
100
- try:
101
- output = generate_initial_msg(prompt)
102
- except:
103
- time.sleep(0.5)
104
 
105
- assert output is not None
106
- initial_messages_pred.append(output)
107
 
108
- df['initial_msg_pred'] = initial_messages_pred
 
 
 
 
 
109
 
110
  df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
111
 
112
 
113
  if __name__ == '__main__':
114
- generate_synthetic_dataset()
 
 
 
 
 
 
 
 
 
1
  import config
2
+ from api_wrappers import hf_data_loader
3
+ from generation_steps import synthetic_end_to_start
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
 
 
 
 
 
5
 
6
+ def run():
7
+ df = hf_data_loader.load_processed_rewriting_dataset_as_pandas()
8
 
9
+ print(f"End -> start synthesis:")
10
+ print(f"GENERATION_MULTIPLIER = {synthetic_end_to_start.GENERATION_MULTIPLIER}")
11
+ print(f"REL_INSERTIONS_THRESHOLD = {synthetic_end_to_start.REL_INSERTIONS_THRESHOLD}")
12
+ print(f"GENERATION_ATTEMPTS = {synthetic_end_to_start.GENERATION_ATTEMPTS}")
13
+ df = synthetic_end_to_start.transform(df)
14
+ print("Done")
15
 
16
  df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
17
 
18
 
19
  if __name__ == '__main__':
20
+ run()
generation_steps/__init__.py ADDED
File without changes
generation_steps/synthetic_end_to_start.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from tqdm import tqdm
3
+
4
+ import config
5
+ import generate_annotated_diffs
6
+ from api_wrappers import grazie_wrapper, hf_data_loader
7
+ import statistics
8
+
9
+ N_EXAMPLES = 5
10
+ GENERATION_MULTIPLIER = 2
11
+ REL_INSERTIONS_THRESHOLD = 0.6
12
+ GENERATION_ATTEMPTS = 5
13
+
14
+
15
+ def get_example_prompt(start_msg, end_msg):
16
+ return f"""START OF THE EXAMPLE
17
+
18
+ For following the edited message:
19
+ START OF THE EDITED COMMIT MESSAGE
20
+ {end_msg}
21
+ END OF THE EDITED COMMIT MESSAGE
22
+
23
+ You would output the following initial commit message:
24
+ START OF THE INITIAL COMMIT MESSAGE
25
+ {start_msg}
26
+ END OF THE INITIAL COMMIT MESSAGE
27
+
28
+ END OF THE EXAMPLE"""
29
+
30
+
31
+ def generate_examples():
32
+ manual_df = hf_data_loader.load_raw_rewriting_dataset_as_pandas()[['commit_msg_start', 'commit_msg_end']]
33
+ manual_df = manual_df.sample(n=N_EXAMPLES, random_state=config.RANDOM_STATE)
34
+ examples = [
35
+ get_example_prompt(row['commit_msg_start'], row['commit_msg_end'])
36
+ for _, row in manual_df.iterrows()
37
+ ]
38
+
39
+ return "\n".join(examples)
40
+
41
+
42
+ EXAMPLES = generate_examples()
43
+
44
+
45
+ def build_prompt(reference, diff):
46
+ return f"""A software developer uses a LLM to generate commit messages.
47
+
48
+ They generated a commit message for the following source code changes:
49
+ START OF THE SOURCE CODE CHANGES
50
+ {diff}
51
+ END OF THE SOURCE CODE CHANGES
52
+
53
+ After generating the commit message the developer understands that it is not perfect. After making dome changes,
54
+ they come up with an edited version of the message. Here is this edited message:
55
+ START OF THE COMMIT MESSAGE
56
+ {reference}
57
+ END OF THE COMMIT MESSAGE
58
+
59
+ Your task is to print the initial, LLM-generated commit message.
60
+ The message you print must share some fragments with the edited message.
61
+ Here are some examples of what you should output:
62
+ START OF THE EXAMPLES LIST
63
+ {EXAMPLES}
64
+ END OF THE EXAMPLES LIST
65
+
66
+
67
+ Print only the initial commit message's text after the
68
+ token "OUTPUT".
69
+
70
+ OUTPUT"""
71
+
72
+
73
+ def generate_start_msg(end_msg, diff):
74
+ prompt = build_prompt(reference=end_msg, diff=diff)
75
+ results = []
76
+
77
+ for i in range(GENERATION_ATTEMPTS):
78
+ start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
79
+
80
+ stats = statistics.get_statistics(start_msg=start_msg_pred, end_msg=end_msg,
81
+ annotated_msg=generate_annotated_diffs.get_annotated_diff(start_msg_pred,
82
+ end_msg))
83
+ if stats["insertions"] < REL_INSERTIONS_THRESHOLD:
84
+ return start_msg_pred
85
+ else:
86
+ results.append((stats["insertions"], start_msg_pred))
87
+
88
+ results.sort()
89
+ return results[0][1]
90
+
91
+
92
+ def transform(df):
93
+ df['end_to_start'] = False
94
+
95
+ generated_data = {
96
+ "hash": [],
97
+ "repo": [],
98
+ "commit_msg_start": [],
99
+ "commit_msg_end": [],
100
+ "mods": []
101
+ }
102
+
103
+ for _, row in tqdm(df.iterrows(), total=len(df)):
104
+ for i in range(GENERATION_MULTIPLIER):
105
+ commit_msg_start_pred = generate_start_msg(end_msg=row["commit_msg_end"],
106
+ diff=row["mods"])
107
+ generated_data["hash"].append(row["hash"])
108
+ generated_data["repo"].append(row["repo"])
109
+ generated_data["commit_msg_start"].append(commit_msg_start_pred)
110
+ generated_data["commit_msg_end"].append(row["commit_msg_end"])
111
+ generated_data["mods"].append(row["mods"])
112
+
113
+ generated_df = pd.DataFrame.from_dict(generated_data)
114
+ generated_df['end_to_start'] = True
115
+
116
+ return pd.concat([df, generated_df], ignore_index=True)
statistics.py CHANGED
@@ -2,35 +2,34 @@ import numpy as np
2
  import pandas as pd
3
 
4
 
5
- def get_statistics_for_df(df: pd.DataFrame, start_col, end_col, annotated_col):
6
- relative_deletions = []
7
- relative_insertions = []
8
- relative_changes = []
9
-
10
- for _, row in df.iterrows():
11
- sum_deletions = 0
12
- sum_insertions = 0
13
- for text, change_type in row[annotated_col]:
14
- if change_type == '-':
15
- sum_deletions += len(text)
16
- elif change_type == '+':
17
- sum_insertions += len(text)
18
-
19
- sum_changes = sum_deletions + sum_insertions
20
- end_length = len(row[end_col])
21
- start_length = len(row[start_col])
22
-
23
- relative_deletions.append(sum_deletions / start_length)
24
- relative_insertions.append(sum_insertions / end_length)
25
- relative_changes.append(sum_changes / end_length)
26
 
27
  return {
28
- "deletions": np.asarray(relative_deletions),
29
- "insertions": np.asarray(relative_insertions),
30
- "changes": np.asarray(relative_changes)
31
  }
32
 
33
 
 
 
 
 
 
 
 
 
34
  def get_statistics_for_manual_df(df):
35
  return get_statistics_for_df(df, start_col="commit_msg_start", end_col='commit_msg_end',
36
  annotated_col='annotated_diff')
 
2
  import pandas as pd
3
 
4
 
5
+ def get_statistics(start_msg, end_msg, annotated_msg):
6
+ sum_deletions = 0
7
+ sum_insertions = 0
8
+ for text, change_type in annotated_msg:
9
+ if change_type == '-':
10
+ sum_deletions += len(text)
11
+ elif change_type == '+':
12
+ sum_insertions += len(text)
13
+
14
+ sum_changes = sum_deletions + sum_insertions
15
+ end_length = len(end_msg)
16
+ start_length = len(start_msg)
 
 
 
 
 
 
 
 
 
17
 
18
  return {
19
+ "deletions": sum_deletions / start_length,
20
+ "insertions": sum_insertions / end_length,
21
+ "changes": sum_changes / end_length
22
  }
23
 
24
 
25
+ def get_statistics_for_df(df: pd.DataFrame, start_col, end_col, annotated_col):
26
+ stats = [get_statistics(row[start_col], row[end_col], row[annotated_col]) for _, row in df.iterrows()]
27
+
28
+ assert len(stats) > 0
29
+
30
+ return {stat_name: np.asarray([e[stat_name] for e in stats]) for stat_name in stats[0]}
31
+
32
+
33
  def get_statistics_for_manual_df(df):
34
  return get_statistics_for_df(df, start_col="commit_msg_start", end_col='commit_msg_end',
35
  annotated_col='annotated_diff')