Petr Tsvetkov
committed on
Commit
β’
6676c5a
1
Parent(s):
2d03034
Generate a dataset for the labeling app
Browse files- api_wrappers/grazie_wrapper.py +1 -1
- api_wrappers/hf_data_loader.py +22 -6
- config.py +5 -1
- generate_annotated_diffs.py +2 -2
- generate_synthetic_dataset.py +1 -1
- generation_steps/examples.py +1 -1
- generation_steps/for_labeling.py +58 -0
- generation_steps/metrics_analysis.py +1 -1
- generation_steps/synthetic_end_to_start.py +1 -1
- generation_steps/synthetic_start_to_end.py +18 -14
api_wrappers/grazie_wrapper.py
CHANGED
@@ -10,7 +10,7 @@ import config
|
|
10 |
client = GrazieApiGatewayClient(
|
11 |
grazie_agent=GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev"),
|
12 |
url=GrazieApiGatewayUrls.STAGING,
|
13 |
-
auth_type=AuthType.
|
14 |
grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
|
15 |
)
|
16 |
|
|
|
10 |
# Shared Grazie API gateway client used for all LLM generation requests in this project.
# Targets the STAGING endpoint and authenticates as a user with the JWT taken
# from config.GRAZIE_API_JWT_TOKEN.
client = GrazieApiGatewayClient(
    grazie_agent=GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev"),
    url=GrazieApiGatewayUrls.STAGING,
    auth_type=AuthType.USER,
    grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
)
|
16 |
|
api_wrappers/hf_data_loader.py
CHANGED
@@ -3,14 +3,14 @@ from datasets import load_dataset
|
|
3 |
import config
|
4 |
|
5 |
|
6 |
-
def
|
7 |
return load_dataset(config.HF_RAW_DATASET_NAME,
|
8 |
split=config.HF_RAW_DATASET_SPLIT,
|
9 |
token=config.HF_TOKEN,
|
10 |
cache_dir=config.CACHE_DIR).to_pandas()
|
11 |
|
12 |
|
13 |
-
def
|
14 |
return load_dataset(path=config.HF_FULL_COMMITS_DATASET_NAME,
|
15 |
name=config.HF_FULL_COMMITS_DATASET_SUBNAME,
|
16 |
split=config.HF_FULL_COMMITS_DATASET_SPLIT,
|
@@ -18,19 +18,35 @@ def load_full_commit_dataset_as_pandas():
|
|
18 |
columns={'message': 'reference'})
|
19 |
|
20 |
|
21 |
-
def
|
22 |
-
manual_rewriting =
|
23 |
["hash", "repo", "commit_msg_start", "commit_msg_end", "session"]]
|
24 |
manual_rewriting.set_index(["hash", "repo"], inplace=True)
|
25 |
|
26 |
-
mods_dataset =
|
27 |
mods_dataset.set_index(["hash", "repo"], inplace=True)
|
28 |
|
29 |
return manual_rewriting.join(other=mods_dataset, how='left').reset_index()
|
30 |
|
31 |
|
32 |
-
def
|
33 |
return load_dataset(config.HF_SYNTHETIC_DATASET_NAME,
|
34 |
split=config.HF_SYNTHETIC_DATASET_SPLIT,
|
35 |
token=config.HF_TOKEN,
|
36 |
cache_dir=config.CACHE_DIR).to_pandas()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import config
|
4 |
|
5 |
|
6 |
+
def load_raw_rewriting_as_pandas():
    """Fetch the raw manual commit-message-rewriting dataset from the HF hub
    and return it as a pandas DataFrame."""
    raw = load_dataset(
        config.HF_RAW_DATASET_NAME,
        split=config.HF_RAW_DATASET_SPLIT,
        token=config.HF_TOKEN,
        cache_dir=config.CACHE_DIR,
    )
    return raw.to_pandas()
|
11 |
|
12 |
|
13 |
+
def load_full_commit_as_pandas():
|
14 |
return load_dataset(path=config.HF_FULL_COMMITS_DATASET_NAME,
|
15 |
name=config.HF_FULL_COMMITS_DATASET_SUBNAME,
|
16 |
split=config.HF_FULL_COMMITS_DATASET_SPLIT,
|
|
|
18 |
columns={'message': 'reference'})
|
19 |
|
20 |
|
21 |
+
def load_processed_rewriting_as_pandas():
    """Return the manual rewriting records joined (by commit id) with the
    'mods' column of the full-commit dataset."""
    rewriting_cols = ["hash", "repo", "commit_msg_start", "commit_msg_end", "session"]
    manual_rewriting = load_raw_rewriting_as_pandas()[rewriting_cols]
    manual_rewriting.set_index(["hash", "repo"], inplace=True)

    mods_dataset = load_full_commit_as_pandas()[["hash", "repo", "mods"]]
    mods_dataset.set_index(["hash", "repo"], inplace=True)

    joined = manual_rewriting.join(other=mods_dataset, how='left')
    return joined.reset_index()
|
30 |
|
31 |
|
32 |
+
def load_synthetic_as_pandas():
    """Fetch the synthetic commit-message-rewriting dataset from the HF hub
    and return it as a pandas DataFrame."""
    synthetic = load_dataset(
        config.HF_SYNTHETIC_DATASET_NAME,
        split=config.HF_SYNTHETIC_DATASET_SPLIT,
        token=config.HF_TOKEN,
        cache_dir=config.CACHE_DIR,
    )
    return synthetic.to_pandas()
|
37 |
+
|
38 |
+
|
39 |
+
def load_full_commit_with_predictions_as_pandas():
    """Return the full-commit dataset with a 'prediction' column attached,
    taken from the predictions dataset and matched on (hash, repo)."""
    full_dataset = load_full_commit_as_pandas()

    predictions = load_dataset(
        config.HF_PREDICTIONS_DATASET_NAME,
        config.HF_PREDICTIONS_DATASET_SUBNAME,
        split=config.HF_PREDICTIONS_DATASET_SPLIT,
        cache_dir=config.CACHE_DIR,
    ).to_pandas()

    # Shuffle deterministically, index by commit id, keep only the prediction text.
    predictions = predictions.sample(frac=1, random_state=config.RANDOM_STATE)
    predictions = predictions.set_index(['hash', 'repo'])[["prediction"]]
    # One prediction per commit: after the shuffle, 'first' is an arbitrary
    # but reproducible choice among duplicates.
    predictions = predictions[~predictions.index.duplicated(keep='first')]

    dataset = full_dataset.join(other=predictions, on=('hash', 'repo'))
    return dataset.reset_index()
|
config.py
CHANGED
@@ -15,6 +15,10 @@ HF_FULL_COMMITS_DATASET_NAME = "JetBrains-Research/lca-commit-message-generation
|
|
15 |
HF_FULL_COMMITS_DATASET_SUBNAME = "commitchronicle-py-long"
|
16 |
HF_FULL_COMMITS_DATASET_SPLIT = "test"
|
17 |
|
|
|
|
|
|
|
|
|
18 |
HF_SYNTHETIC_DATASET_NAME = "petrtsv-jb/synthetic-commit-msg-rewriting"
|
19 |
HF_SYNTHETIC_DATASET_SPLIT = 'train'
|
20 |
|
@@ -24,8 +28,8 @@ CACHE_DIR.mkdir(exist_ok=True)
|
|
24 |
OUTPUT_DIR = Path("output")
|
25 |
OUTPUT_DIR.mkdir(exist_ok=True)
|
26 |
|
27 |
-
|
28 |
END_TO_START_ARTIFACT = OUTPUT_DIR / "end_to_start.csv"
|
29 |
START_TO_END_ARTIFACT = OUTPUT_DIR / "start_to_end.csv"
|
30 |
SYNTHETIC_DATASET_ARTIFACT = OUTPUT_DIR / "synthetic.csv"
|
31 |
METRICS_CORRELATIONS_ARTIFACT = OUTPUT_DIR / "metrics_correlations.csv"
|
|
|
|
15 |
HF_FULL_COMMITS_DATASET_SUBNAME = "commitchronicle-py-long"
HF_FULL_COMMITS_DATASET_SPLIT = "test"

# Dataset holding model-generated commit messages; its "prediction" column is
# joined onto the full-commit data by load_full_commit_with_predictions_as_pandas().
HF_PREDICTIONS_DATASET_NAME = "JetBrains-Research/lca-results"
HF_PREDICTIONS_DATASET_SUBNAME = "cmg_gpt_4_0613"
HF_PREDICTIONS_DATASET_SPLIT = "test"

HF_SYNTHETIC_DATASET_NAME = "petrtsv-jb/synthetic-commit-msg-rewriting"
HF_SYNTHETIC_DATASET_SPLIT = 'train'
|
24 |
|
|
|
28 |
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(exist_ok=True)

# CSV artifacts written by the individual pipeline steps.
END_TO_START_ARTIFACT = OUTPUT_DIR / "end_to_start.csv"
START_TO_END_ARTIFACT = OUTPUT_DIR / "start_to_end.csv"
SYNTHETIC_DATASET_ARTIFACT = OUTPUT_DIR / "synthetic.csv"
METRICS_CORRELATIONS_ARTIFACT = OUTPUT_DIR / "metrics_correlations.csv"
# Dataset consumed by the labeling app (written by generation_steps/for_labeling.py).
DATA_FOR_LABELING_ARTIFACT = OUTPUT_DIR / "data_for_labeling.csv"
|
generate_annotated_diffs.py
CHANGED
@@ -26,14 +26,14 @@ def annotated_diff_for_row(row):
|
|
26 |
|
27 |
|
28 |
def manual_data_with_annotated_diffs():
|
29 |
-
df = hf_data_loader.
|
30 |
annotated = df.apply(annotated_diff_for_row, axis=1)
|
31 |
df['annotated_diff'] = annotated
|
32 |
return df
|
33 |
|
34 |
|
35 |
def synthetic_data_with_annotated_diffs():
|
36 |
-
df = hf_data_loader.
|
37 |
annotated = df.apply(annotated_diff_for_row, axis=1)
|
38 |
df['annotated_diff'] = annotated
|
39 |
return df
|
|
|
26 |
|
27 |
|
28 |
def manual_data_with_annotated_diffs():
    """Load the raw manual rewriting data and attach an 'annotated_diff'
    column computed row-by-row with annotated_diff_for_row."""
    data = hf_data_loader.load_raw_rewriting_as_pandas()
    data['annotated_diff'] = data.apply(annotated_diff_for_row, axis=1)
    return data
|
33 |
|
34 |
|
35 |
def synthetic_data_with_annotated_diffs():
    """Load the synthetic dataset and attach an 'annotated_diff' column
    computed row-by-row with annotated_diff_for_row."""
    data = hf_data_loader.load_synthetic_as_pandas()
    data['annotated_diff'] = data.apply(annotated_diff_for_row, axis=1)
    return data
|
generate_synthetic_dataset.py
CHANGED
@@ -4,7 +4,7 @@ from generation_steps import synthetic_end_to_start, synthetic_start_to_end, met
|
|
4 |
|
5 |
|
6 |
def run():
|
7 |
-
df = hf_data_loader.
|
8 |
|
9 |
df = synthetic_end_to_start.transform(df)
|
10 |
df = synthetic_start_to_end.transform(df)
|
|
|
4 |
|
5 |
|
6 |
def run():
|
7 |
+
df = hf_data_loader.load_processed_rewriting_as_pandas()
|
8 |
|
9 |
df = synthetic_end_to_start.transform(df)
|
10 |
df = synthetic_start_to_end.transform(df)
|
generation_steps/examples.py
CHANGED
@@ -36,7 +36,7 @@ END OF THE IMPROVED COMMIT MESSAGE
|
|
36 |
END OF THE EXAMPLE"""
|
37 |
|
38 |
|
39 |
-
manual_df = hf_data_loader.
|
40 |
manual_df = manual_df.sample(n=N_EXAMPLES, random_state=config.RANDOM_STATE)
|
41 |
|
42 |
|
|
|
36 |
END OF THE EXAMPLE"""
|
37 |
|
38 |
|
39 |
+
manual_df = hf_data_loader.load_raw_rewriting_as_pandas()[['commit_msg_start', 'commit_msg_end']]
|
40 |
manual_df = manual_df.sample(n=N_EXAMPLES, random_state=config.RANDOM_STATE)
|
41 |
|
42 |
|
generation_steps/for_labeling.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
import config
|
6 |
+
from api_wrappers import hf_data_loader
|
7 |
+
from generation_steps import synthetic_start_to_end
|
8 |
+
|
9 |
+
|
10 |
+
def transform(df):
    """Build the dataset for the labeling app.

    For every commit in df, produce a "prediction" and an "enhanced" message:
    commits that were rewritten manually reuse the human start/end messages,
    all others keep the model prediction and get an LLM-improved version.
    Writes the result to config.DATA_FOR_LABELING_ARTIFACT and returns it.
    """
    print(f"Generating data for labeling:")
    synthetic_start_to_end.print_config()
    tqdm.pandas()

    manual_df = hf_data_loader.load_raw_rewriting_as_pandas()

    # Shuffle deterministically, index by commit id, keep only the manually
    # written start/end messages.
    manual_df = manual_df.sample(frac=1, random_state=config.RANDOM_STATE
                                 ).set_index(['hash', 'repo'])[["commit_msg_start", "commit_msg_end"]]

    # One row per commit: after the shuffle, 'first' is an arbitrary but
    # reproducible choice among duplicate (hash, repo) entries.
    manual_df = manual_df[~manual_df.index.duplicated(keep='first')]

    def get_is_manually_rewritten(row):
        # True when this commit appears in the manual rewriting dataset.
        commit_id = (row['hash'], row['repo'])
        return commit_id in manual_df.index

    result = df  # NOTE(review): alias, not a copy — mutates the caller's frame in place
    result['manual_sample'] = result.progress_apply(get_is_manually_rewritten, axis=1)

    def get_prediction_message(row):
        # Manual samples expose the human starting message instead of the
        # model prediction.
        commit_id = (row['hash'], row['repo'])
        if row['manual_sample']:
            return manual_df.loc[commit_id]['commit_msg_start']
        return row['prediction']

    def get_enhanced_message(row):
        # Manual samples reuse the human-written final message; all other rows
        # get an LLM-improved version of the model prediction.
        commit_id = (row['hash'], row['repo'])
        if row['manual_sample']:
            return manual_df.loc[commit_id]['commit_msg_end']
        return synthetic_start_to_end.generate_end_msg(start_msg=row["prediction"],
                                                       diff=row["mods"])

    # Order matters: 'enhanced' must be computed while 'prediction' still holds
    # the raw model output, because get_enhanced_message reads row["prediction"].
    result['enhanced'] = result.progress_apply(get_enhanced_message, axis=1)
    result['prediction'] = result.progress_apply(get_prediction_message, axis=1)
    # Serialize the structured mods so they survive the CSV round-trip.
    result['mods'] = result['mods'].progress_apply(json.dumps)

    result.to_csv(config.DATA_FOR_LABELING_ARTIFACT)
    print("Done")
    return result
|
49 |
+
|
50 |
+
|
51 |
+
def main():
    """Entry point: generate the labeling dataset from commits + predictions."""
    # Use a smaller retry budget per message than the full-synthesis default.
    synthetic_start_to_end.GENERATION_ATTEMPTS = 3
    transform(hf_data_loader.load_full_commit_with_predictions_as_pandas())


if __name__ == '__main__':
    main()
|
generation_steps/metrics_analysis.py
CHANGED
@@ -77,7 +77,7 @@ METRICS = {
|
|
77 |
|
78 |
|
79 |
def attach_references(df):
|
80 |
-
reference_df = hf_data_loader.
|
81 |
df = df.set_index(["hash", "repo"])
|
82 |
return df.join(other=reference_df, how="left").reset_index()
|
83 |
|
|
|
77 |
|
78 |
|
79 |
def attach_references(df):
    """Join the ground-truth 'reference' message onto df, matching rows
    on the (hash, repo) commit id."""
    references = hf_data_loader.load_full_commit_as_pandas()
    references = references.set_index(["hash", "repo"])[["reference"]]
    indexed = df.set_index(["hash", "repo"])
    return indexed.join(other=references, how="left").reset_index()
|
83 |
|
generation_steps/synthetic_end_to_start.py
CHANGED
@@ -98,7 +98,7 @@ def transform(df):
|
|
98 |
|
99 |
|
100 |
def main():
|
101 |
-
df = hf_data_loader.
|
102 |
transform(df)
|
103 |
|
104 |
|
|
|
98 |
|
99 |
|
100 |
def main():
    """Run the end-to-start synthesis over the processed rewriting data."""
    transform(hf_data_loader.load_processed_rewriting_as_pandas())
|
103 |
|
104 |
|
generation_steps/synthetic_start_to_end.py
CHANGED
@@ -12,7 +12,7 @@ REL_DELETIONS_THRESHOLD = 0.75
|
|
12 |
GENERATION_ATTEMPTS = 5
|
13 |
|
14 |
|
15 |
-
def build_prompt(
|
16 |
return f"""A LLM generated a commit message for the following source code changes:
|
17 |
START OF THE SOURCE CODE CHANGES
|
18 |
{diff}
|
@@ -20,7 +20,7 @@ END OF THE SOURCE CODE CHANGES
|
|
20 |
|
21 |
Here is the message the LLM generated:
|
22 |
START OF THE COMMIT MESSAGE
|
23 |
-
{
|
24 |
END OF THE COMMIT MESSAGE
|
25 |
|
26 |
This generated message is not perfect. Your task is to rewrite and improve it.
|
@@ -40,20 +40,20 @@ token "OUTPUT".
|
|
40 |
OUTPUT"""
|
41 |
|
42 |
|
43 |
-
def
|
44 |
-
prompt = build_prompt(
|
45 |
results = []
|
46 |
|
47 |
for i in range(GENERATION_ATTEMPTS):
|
48 |
-
|
49 |
|
50 |
-
stats = statistics.get_statistics(start_msg=
|
51 |
-
annotated_msg=generate_annotated_diffs.get_annotated_diff(
|
52 |
-
|
53 |
if stats["deletions"] < REL_DELETIONS_THRESHOLD:
|
54 |
-
return
|
55 |
else:
|
56 |
-
results.append((stats["deletions"],
|
57 |
|
58 |
results.sort()
|
59 |
return results[0][1]
|
@@ -62,13 +62,17 @@ def generate_start_msg(end_msg, diff):
|
|
62 |
COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]
|
63 |
|
64 |
|
65 |
-
def
|
66 |
-
print(f"Start -> send synthesis:")
|
67 |
print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
|
68 |
print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
|
69 |
print(f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}")
|
70 |
print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")
|
71 |
|
|
|
|
|
|
|
|
|
|
|
72 |
df['start_to_end'] = False
|
73 |
|
74 |
generated_data = {
|
@@ -80,8 +84,8 @@ def transform(df):
|
|
80 |
|
81 |
for _, row in tqdm(df.iterrows(), total=len(df)):
|
82 |
for i in range(GENERATION_MULTIPLIER):
|
83 |
-
commit_msg_end_pred =
|
84 |
-
|
85 |
|
86 |
generated_data["commit_msg_end"].append(commit_msg_end_pred)
|
87 |
for col in COLS_TO_KEEP:
|
|
|
12 |
GENERATION_ATTEMPTS = 5
|
13 |
|
14 |
|
15 |
+
def build_prompt(prediction, diff):
|
16 |
return f"""A LLM generated a commit message for the following source code changes:
|
17 |
START OF THE SOURCE CODE CHANGES
|
18 |
{diff}
|
|
|
20 |
|
21 |
Here is the message the LLM generated:
|
22 |
START OF THE COMMIT MESSAGE
|
23 |
+
{prediction}
|
24 |
END OF THE COMMIT MESSAGE
|
25 |
|
26 |
This generated message is not perfect. Your task is to rewrite and improve it.
|
|
|
40 |
OUTPUT"""
|
41 |
|
42 |
|
43 |
+
def generate_end_msg(start_msg, diff):
    """Ask the LLM to improve start_msg for the given diff.

    Retries up to GENERATION_ATTEMPTS times, returning the first candidate
    whose relative deletions stay below REL_DELETIONS_THRESHOLD; if none
    qualifies, returns the candidate with the fewest deletions.
    """
    prompt = build_prompt(prediction=start_msg, diff=diff)
    attempts = []

    for _ in range(GENERATION_ATTEMPTS):
        candidate = grazie_wrapper.generate_for_prompt(prompt)

        stats = statistics.get_statistics(start_msg=start_msg, end_msg=candidate,
                                          annotated_msg=generate_annotated_diffs.get_annotated_diff(start_msg,
                                                                                                    candidate))
        if stats["deletions"] < REL_DELETIONS_THRESHOLD:
            return candidate
        attempts.append((stats["deletions"], candidate))

    # No attempt passed the threshold: fall back to the least-destructive one.
    return min(attempts)[1]
|
|
|
62 |
COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]
|
63 |
|
64 |
|
65 |
+
def print_config():
    """Print the generation hyper-parameters currently in effect."""
    settings = [
        ("NUMBER OF EXAMPLES PER PROMPT", examples.N_EXAMPLES),
        ("GENERATION_MULTIPLIER", GENERATION_MULTIPLIER),
        ("REL_DELETIONS_THRESHOLD", REL_DELETIONS_THRESHOLD),
        ("GENERATION_ATTEMPTS", GENERATION_ATTEMPTS),
    ]
    for label, value in settings:
        print(f"{label} = {value}")
|
70 |
|
71 |
+
|
72 |
+
def transform(df):
|
73 |
+
print(f"Start -> send synthesis:")
|
74 |
+
print_config()
|
75 |
+
|
76 |
df['start_to_end'] = False
|
77 |
|
78 |
generated_data = {
|
|
|
84 |
|
85 |
for _, row in tqdm(df.iterrows(), total=len(df)):
|
86 |
for i in range(GENERATION_MULTIPLIER):
|
87 |
+
commit_msg_end_pred = generate_end_msg(start_msg=row["commit_msg_start"],
|
88 |
+
diff=row["mods"])
|
89 |
|
90 |
generated_data["commit_msg_end"].append(commit_msg_end_pred)
|
91 |
for col in COLS_TO_KEEP:
|