Petr Tsvetkov commited on
Commit
0c136d8
β€’
1 Parent(s): ba39281

Synthetic dataset generation

Browse files
change_visualizer.py CHANGED
@@ -2,28 +2,41 @@ import gradio as gr
2
 
3
  import generate_annotated_diffs
4
 
5
- df = generate_annotated_diffs.data_with_annotated_diffs()
6
- n_diffs = len(df)
7
 
8
 
9
- def update_view(diff_idx):
10
  diff_idx -= 1
11
- return df.iloc[diff_idx]['annotated_diff'], df.iloc[diff_idx]['commit_msg_start'], df.iloc[diff_idx][
12
- 'commit_msg_end'], df.iloc[diff_idx][
13
- 'session'], f"https://github.com/{df.iloc[diff_idx]['repo']}/commit/{df.iloc[diff_idx]['hash']}"
 
14
 
15
 
16
  if __name__ == '__main__':
17
  with gr.Blocks(theme=gr.themes.Soft()) as application:
18
- slider = gr.Slider(minimum=1, maximum=n_diffs, step=1, value=1, label=f"Sample number (total: {n_diffs})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- diff_view = gr.Highlightedtext(combine_adjacent=True, color_map={'+': "green", '-': "red"})
21
- start_view = gr.Textbox(interactive=False, label="Start message", container=True)
22
- end_view = gr.Textbox(interactive=False, label="End message", container=True)
23
- session_view = gr.Textbox(interactive=False, label="Session", container=True)
24
- link_view = gr.Markdown()
25
-
26
- slider.change(update_view, inputs=slider, outputs=[diff_view, start_view, end_view, session_view, link_view])
27
-
28
- application.load(update_view, inputs=slider, outputs=[diff_view, start_view, end_view, session_view, link_view])
29
  application.launch()
 
2
 
3
  import generate_annotated_diffs
4
 
5
+ df_manual = generate_annotated_diffs.manual_data_with_annotated_diffs()
6
+ n_diffs_manual = len(df_manual)
7
 
8
 
9
+ def update_manual_view(diff_idx):
10
  diff_idx -= 1
11
+ return df_manual.iloc[diff_idx]['annotated_diff'], df_manual.iloc[diff_idx]['commit_msg_start'], \
12
+ df_manual.iloc[diff_idx][
13
+ 'commit_msg_end'], df_manual.iloc[diff_idx][
14
+ 'session'], f"https://github.com/{df_manual.iloc[diff_idx]['repo']}/commit/{df_manual.iloc[diff_idx]['hash']}"
15
 
16
 
17
  if __name__ == '__main__':
18
  with gr.Blocks(theme=gr.themes.Soft()) as application:
19
+ with gr.Tab("Manual"):
20
+ slider_manual = gr.Slider(minimum=1, maximum=n_diffs_manual, step=1, value=1,
21
+ label=f"Sample number (total: {n_diffs_manual})")
22
+
23
+ diff_view_manual = gr.Highlightedtext(combine_adjacent=True, color_map={'+': "green", '-': "red"})
24
+ start_view_manual = gr.Textbox(interactive=False, label="Start message", container=True)
25
+ end_view_manual = gr.Textbox(interactive=False, label="End message", container=True)
26
+ session_view_manual = gr.Textbox(interactive=False, label="Session", container=True)
27
+ link_view_manual = gr.Markdown()
28
+ view_manual = [
29
+ diff_view_manual,
30
+ start_view_manual,
31
+ end_view_manual,
32
+ session_view_manual,
33
+ link_view_manual
34
+ ]
35
+
36
+ slider_manual.change(update_manual_view, inputs=slider_manual,
37
+ outputs=view_manual)
38
+
39
+ application.load(update_manual_view, inputs=slider_manual,
40
+ outputs=view_manual)
41
 
 
 
 
 
 
 
 
 
 
42
  application.launch()
config.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  from pathlib import Path
3
 
 
 
4
  HF_TOKEN = os.environ.get('HF_TOKEN')
5
  HF_RAW_DATASET_NAME = "petrtsv-jb/commit-msg-rewriting"
6
  HF_RAW_DATASET_SPLIT = 'train'
@@ -11,4 +13,4 @@ CACHE_DIR.mkdir(exist_ok=True)
11
  OUTPUT_DIR = Path("output")
12
  OUTPUT_DIR.mkdir(exist_ok=True)
13
 
14
- ANNOTATED_DIFFS_ARTIFACT = OUTPUT_DIR / "annotated_diffs.csv"
 
1
  import os
2
  from pathlib import Path
3
 
4
+ GRAZIE_API_JWT_TOKEN = os.environ.get("GRAZIE_API_JWT_TOKEN")
5
+
6
  HF_TOKEN = os.environ.get('HF_TOKEN')
7
  HF_RAW_DATASET_NAME = "petrtsv-jb/commit-msg-rewriting"
8
  HF_RAW_DATASET_SPLIT = 'train'
 
13
  OUTPUT_DIR = Path("output")
14
  OUTPUT_DIR.mkdir(exist_ok=True)
15
 
16
+ SYNTHETIC_DATASET_ARTIFACT = OUTPUT_DIR / "synthetic.csv"
generate_annotated_diffs.py CHANGED
@@ -1,49 +1,8 @@
1
- from datetime import datetime
2
-
3
  import diff_match_patch as dmp_module
4
 
5
  import hf_data_loader
6
 
7
 
8
- def group_changes(changes):
9
- groups = {}
10
- for change in changes:
11
- group = datetime.fromisoformat(change['ts'])
12
- if group not in groups:
13
- groups[group] = []
14
- groups[group].append(change)
15
-
16
- grouped_changes = []
17
- for group in sorted(groups.keys()):
18
- groups[group].sort(key=lambda x: x['p'])
19
- grouped_changes.append(groups[group])
20
-
21
- return grouped_changes
22
-
23
-
24
- def fill_in_annotation_gaps(annotated_text):
25
- seg_start = None
26
- seg_type = None
27
-
28
- for i, e in enumerate(annotated_text):
29
- if e[1] is None:
30
- continue
31
-
32
- if seg_type is None:
33
- seg_start = i
34
- elif seg_type != e[1]:
35
- for j in range(seg_start, i):
36
- annotated_text[j][1] = seg_type
37
- seg_start = i
38
- seg_type = e[1]
39
-
40
- if seg_start is not None:
41
- for j in range(seg_start, len(annotated_text)):
42
- annotated_text[j][1] = seg_type
43
-
44
- return annotated_text
45
-
46
-
47
  def get_annotated_diff(start_text, end_text):
48
  dmp = dmp_module.diff_match_patch()
49
  dmp_mapping = {
@@ -60,14 +19,20 @@ def get_annotated_diff(start_text, end_text):
60
  return result
61
 
62
 
63
- def annotated_diff_for_row(row):
64
  start = row['commit_msg_start']
65
  end = row['commit_msg_end']
66
  return get_annotated_diff(start, end)
67
 
68
 
69
- def data_with_annotated_diffs():
70
- df = hf_data_loader.load_raw_dataset_as_pandas()
71
- annotated = df.apply(annotated_diff_for_row, axis=1)
 
 
 
 
 
 
72
  df['annotated_diff'] = annotated
73
  return df
 
 
 
1
  import diff_match_patch as dmp_module
2
 
3
  import hf_data_loader
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def get_annotated_diff(start_text, end_text):
7
  dmp = dmp_module.diff_match_patch()
8
  dmp_mapping = {
 
19
  return result
20
 
21
 
22
+ def annotated_diff_for_row_manual_df(row):
23
  start = row['commit_msg_start']
24
  end = row['commit_msg_end']
25
  return get_annotated_diff(start, end)
26
 
27
 
28
+ def annotated_diff_for_row_synthetic_df(row):
29
+ start = row['initial_msg_pred']
30
+ end = row['reference']
31
+ return get_annotated_diff(start, end)
32
+
33
+
34
+ def manual_data_with_annotated_diffs():
35
+ df = hf_data_loader.load_raw_rewriting_dataset_as_pandas()
36
+ annotated = df.apply(annotated_diff_for_row_manual_df, axis=1)
37
  df['annotated_diff'] = annotated
38
  return df
generate_synthetic_dataset.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from grazie.api.client.chat.prompt import ChatPrompt
2
+ from grazie.api.client.endpoints import GrazieApiGatewayUrls
3
+ from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType
4
+ from grazie.api.client.profiles import LLMProfile
5
+ from tqdm import tqdm
6
+
7
+ import config
8
+ import hf_data_loader
9
+
10
+ client = GrazieApiGatewayClient(
11
+ grazie_agent=GrazieAgent(name="commit-rewriting-summary-generation", version="dev"),
12
+ url=GrazieApiGatewayUrls.STAGING,
13
+ auth_type=AuthType.SERVICE,
14
+ grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
15
+ )
16
+
17
+
18
+ def build_prompt(reference, diff):
19
+ return f"""A software developer uses a LLM to generate commit messages.
20
+
21
+ They generated a commit message for the following source code changes:
22
+ START OF THE SOURCE CODE CHANGES
23
+ {diff}
24
+ END OF THE SOURCE CODE CHANGES
25
+
26
+ After generating the commit message the developer understands that it is not perfect. After making dome changes,
27
+ they come up with an edited version of the message. Here is this edited message:
28
+ START OF THE COMMIT MESSAGE
29
+ {reference}
30
+ END OF THE COMMIT MESSAGE
31
+
32
+ Your task is to print the initial, LLM-generated commit message. Print only the initial commit message's text after the
33
+ token "OUTPUT".
34
+
35
+ OUTPUT"""
36
+
37
+
38
+ def generate_prompt_for_row(row):
39
+ reference = row['reference']
40
+ diff = row['mods']
41
+ return build_prompt(reference, diff)
42
+
43
+
44
+ def generate_initial_msg(prompt):
45
+ commit_msg = client.chat(
46
+ chat=ChatPrompt()
47
+ .add_system("You are a helpful assistant.")
48
+ .add_user(prompt),
49
+ profile=LLMProfile("gpt-4-1106-preview")
50
+ ).content
51
+
52
+ return commit_msg
53
+
54
+
55
+ def generate_synthetic_dataset():
56
+ df = hf_data_loader.load_full_commit_dataset_as_pandas()
57
+ df['initial_msg_prompt'] = df.apply(generate_prompt_for_row, axis=1)
58
+ initial_messages_pred = []
59
+
60
+ for prompt in tqdm(df['initial_msg_prompt']):
61
+ initial_messages_pred.append(generate_initial_msg(prompt))
62
+
63
+ df['initial_msg_pred'] = initial_messages_pred
64
+
65
+ df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
66
+
67
+
68
+ if __name__ == '__main__':
69
+ generate_synthetic_dataset()
hf_data_loader.py CHANGED
@@ -3,11 +3,16 @@ from datasets import load_dataset
3
  import config
4
 
5
 
6
- def load_raw_dataset_as_pandas():
7
  return load_dataset(config.HF_RAW_DATASET_NAME,
8
  split=config.HF_RAW_DATASET_SPLIT,
9
  token=config.HF_TOKEN,
10
  cache_dir=config.CACHE_DIR).to_pandas()
11
 
12
 
13
- load_raw_dataset_as_pandas()
 
 
 
 
 
 
3
  import config
4
 
5
 
6
+ def load_raw_rewriting_dataset_as_pandas():
7
  return load_dataset(config.HF_RAW_DATASET_NAME,
8
  split=config.HF_RAW_DATASET_SPLIT,
9
  token=config.HF_TOKEN,
10
  cache_dir=config.CACHE_DIR).to_pandas()
11
 
12
 
13
+ def load_full_commit_dataset_as_pandas():
14
+ return load_dataset("JetBrains-Research/lca-commit-message-generation",
15
+ "commitchronicle-py-long",
16
+ split="test",
17
+ cache_dir=config.CACHE_DIR).to_pandas().rename(
18
+ columns={'message': 'reference'})