Petr Tsvetkov
Add checkpoints
e027012
raw
history blame
3.25 kB
import pandas as pd
from tqdm import tqdm
import config
import generate_annotated_diffs
import statistics
from api_wrappers import grazie_wrapper
from generation_steps import examples
# Number of messages generated for every input row in transform().
GENERATION_MULTIPLIER = 3
# A generated message is accepted immediately when its "deletions" statistic
# is strictly below this value (see generate_start_msg()).
REL_DELETIONS_THRESHOLD = 0.75
# Maximum number of generation attempts made per message in generate_start_msg().
GENERATION_ATTEMPTS = 5
def build_prompt(reference, diff):
    """Build the prompt asking the model to rewrite a generated commit message.

    Args:
        reference: The commit message that the prompt presents as
            LLM-generated and asks the model to improve.
        diff: The source code changes the commit message describes.

    Returns:
        The full prompt text. It embeds the diff, the reference message and
        the few-shot examples from ``examples.EXAMPLES_START_TO_END``, and
        ends with the ``OUTPUT`` token after which the model is expected to
        print its answer.
    """
    few_shot_examples = examples.EXAMPLES_START_TO_END

    prompt = f"""A LLM generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES
Here is the message the LLM generated:
START OF THE COMMIT MESSAGE
{reference}
END OF THE COMMIT MESSAGE
This generated message is not perfect. Your task is to rewrite and improve it.
You have to simulate a human software developer who manually rewrites the LLM-generated commit message,
so the message you print must share some fragments with the generated message.
Your message should be concise.
Follow the Conventional Commits guidelines.
Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{few_shot_examples}
END OF THE EXAMPLES LIST
Print only the improved commit message's text after the
token "OUTPUT".
OUTPUT"""
    return prompt
def generate_start_msg(end_msg, diff):
    """Generate a rewritten message for ``end_msg``, retrying until one is acceptable.

    Up to ``GENERATION_ATTEMPTS`` candidates are requested from
    ``grazie_wrapper`` using the prompt built from ``end_msg`` and ``diff``.
    The first candidate whose ``"deletions"`` statistic is strictly below
    ``REL_DELETIONS_THRESHOLD`` is returned right away. If no candidate
    qualifies, the one with the smallest ``"deletions"`` value is returned
    (ties are broken by comparing the message texts).

    NOTE(review): transform() passes the *start* message as ``end_msg`` and
    stores the result as ``commit_msg_end``; the names here look inherited
    from the opposite-direction step -- confirm before renaming anything.

    Args:
        end_msg: The message given to the model as the text to rewrite.
        diff: The source code changes the message describes.

    Returns:
        The selected generated message.
    """
    prompt = build_prompt(reference=end_msg, diff=diff)
    rejected = []

    for _ in range(GENERATION_ATTEMPTS):
        candidate = grazie_wrapper.generate_for_prompt(prompt)

        annotated = generate_annotated_diffs.get_annotated_diff(candidate, end_msg)
        stats = statistics.get_statistics(
            start_msg=candidate,
            end_msg=end_msg,
            annotated_msg=annotated,
        )

        deletions = stats["deletions"]
        if deletions < REL_DELETIONS_THRESHOLD:
            return candidate

        # Keep the rejected candidate so the least-bad one can be used as a fallback.
        rejected.append((deletions, candidate))

    _, best_candidate = sorted(rejected)[0]
    return best_candidate
# Columns copied unchanged from each source row into every row generated by transform().
COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]
def transform(df):
    """Augment ``df`` with synthesized messages and write the result to disk.

    For every row, ``GENERATION_MULTIPLIER`` messages are generated from the
    row's ``commit_msg_start`` and ``mods`` values and stored in a
    ``commit_msg_end`` column, together with the row's ``COLS_TO_KEEP``
    values. Original rows are flagged with ``start_to_end = False`` and
    generated rows with ``start_to_end = True``.

    The input frame is not modified: the flag is added to a copy.

    Args:
        df: DataFrame that contains at least the columns in ``COLS_TO_KEEP``.

    Returns:
        A new DataFrame with the original rows followed by the generated
        rows (index reset). The same frame is written to
        ``config.START_TO_END_ARTIFACT`` as CSV.
    """
    print("Start -> end synthesis:")
    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
    print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
    print(f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}")
    print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")

    # assign() returns a new frame, so the caller's DataFrame stays untouched.
    original_df = df.assign(start_to_end=False)

    generated_data = {"commit_msg_end": []}
    for col in COLS_TO_KEEP:
        generated_data[col] = []

    for _index, row in tqdm(df.iterrows(), total=len(df)):
        for _ in range(GENERATION_MULTIPLIER):
            commit_msg_end_pred = generate_start_msg(
                end_msg=row["commit_msg_start"],
                diff=row["mods"],
            )

            generated_data["commit_msg_end"].append(commit_msg_end_pred)
            for col in COLS_TO_KEEP:
                generated_data[col].append(row[col])

    generated_df = pd.DataFrame.from_dict(generated_data)
    generated_df["start_to_end"] = True

    result = pd.concat([original_df, generated_df], ignore_index=True)
    result.to_csv(config.START_TO_END_ARTIFACT)

    print("Done")
    return result
def main():
    """Read the CSV at ``config.END_TO_START_ARTIFACT`` and run ``transform`` on it."""
    source_df = pd.read_csv(config.END_TO_START_ARTIFACT, index_col=[0])
    transform(source_df)
# Run the synthesis step only when this file is executed as a script.
if __name__ == '__main__':
    main()