|
import pandas as pd |
|
from tqdm import tqdm |
|
|
|
import config |
|
import generate_annotated_diffs |
|
import statistics |
|
from api_wrappers import grazie_wrapper |
|
from generation_steps import examples |
|
|
|
# Number of synthetic messages generated for every input row.
GENERATION_MULTIPLIER = 3
# A candidate whose "deletions" statistic is below this value is accepted
# immediately, without further generation attempts.
REL_DELETIONS_THRESHOLD = 0.75
# Maximum number of model calls made while looking for an acceptable candidate.
GENERATION_ATTEMPTS = 5
|
|
|
|
|
def build_prompt(reference, diff) -> str:
    """Build the prompt asking the model to rewrite an LLM-generated message.

    The prompt embeds the source code changes, the message to improve and
    the few-shot examples from ``examples.EXAMPLES_START_TO_END``. It ends
    with the ``OUTPUT`` token, after which the model is told to print its
    answer.

    Args:
        reference: Commit message to be rewritten; interpolated verbatim
            between the COMMIT MESSAGE markers.
        diff: Source code changes the message describes; interpolated
            verbatim between the SOURCE CODE CHANGES markers.

    Returns:
        The complete prompt text.
    """
    return f"""A LLM generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES

Here is the message the LLM generated:
START OF THE COMMIT MESSAGE
{reference}
END OF THE COMMIT MESSAGE

This generated message is not perfect. Your task is to rewrite and improve it.
You have to simulate a human software developer who manually rewrites the LLM-generated commit message,
so the message you print must share some fragments with the generated message.
Your message should be concise.
Follow the Conventional Commits guidelines.
Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{examples.EXAMPLES_START_TO_END}
END OF THE EXAMPLES LIST


Print only the improved commit message's text after the
token "OUTPUT".

OUTPUT"""
|
|
|
|
|
def generate_start_msg(end_msg, diff):
    """Generate a rewritten commit message, retrying until one is acceptable.

    The model is queried up to ``GENERATION_ATTEMPTS`` times with the same
    prompt. The first candidate whose ``"deletions"`` statistic is below
    ``REL_DELETIONS_THRESHOLD`` is returned immediately. If no candidate
    qualifies, the one with the smallest ``"deletions"`` value is returned.

    NOTE(review): the caller in this file passes ``row["commit_msg_start"]``
    as ``end_msg`` and stores the result as ``commit_msg_end``, so the
    start/end roles handed to the statistics helpers look reversed for this
    direction -- confirm this is intended.

    Args:
        end_msg: Reference message embedded in the prompt and used as the
            ``end_msg`` side when computing statistics.
        diff: Source code changes embedded in the prompt.

    Returns:
        The selected generated message.

    Raises:
        ValueError: If no candidate was generated at all, i.e.
            ``GENERATION_ATTEMPTS`` is smaller than 1.
    """
    prompt = build_prompt(reference=end_msg, diff=diff)
    rejected = []

    for _ in range(GENERATION_ATTEMPTS):
        start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)

        annotated_msg = generate_annotated_diffs.get_annotated_diff(start_msg_pred, end_msg)
        stats = statistics.get_statistics(start_msg=start_msg_pred,
                                          end_msg=end_msg,
                                          annotated_msg=annotated_msg)
        deletions = stats["deletions"]

        if deletions < REL_DELETIONS_THRESHOLD:
            return start_msg_pred
        rejected.append((deletions, start_msg_pred))

    if not rejected:
        raise ValueError("GENERATION_ATTEMPTS must be at least 1 to generate a message")

    # Tuples compare by score first, then by message text, so ties are
    # resolved exactly as a full sort of the list would resolve them.
    return min(rejected)[1]
|
|
|
|
|
# Columns copied unchanged from a source row onto each row generated from it.
COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]
|
|
|
|
|
def transform(df):
    """Augment ``df`` with synthetic start-to-end rows and save the result.

    For every input row, ``GENERATION_MULTIPLIER`` messages are generated
    from the row's ``commit_msg_start`` and ``mods`` values. Each generated
    message becomes a new row holding the message in ``commit_msg_end``, the
    source row's ``COLS_TO_KEEP`` values and ``start_to_end=True``. The
    original rows are flagged ``start_to_end=False``. Both sets are
    concatenated, written as CSV to ``config.START_TO_END_ARTIFACT`` and
    returned.

    The caller's frame is left untouched: the flag is set on a copy.

    Args:
        df: DataFrame containing at least the columns in ``COLS_TO_KEEP``.

    Returns:
        A DataFrame with the original rows followed by the generated rows,
        re-indexed from zero.
    """
    print("Start -> end synthesis:")
    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
    print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
    print(f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}")
    print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")

    # assign() returns a new frame, so the argument is not mutated.
    original_rows = df.assign(start_to_end=False)

    # Column order matters for the resulting frame: generated message first,
    # then the columns carried over from the source row.
    generated_data = {"commit_msg_end": []}
    for col in COLS_TO_KEEP:
        generated_data[col] = []

    for _, row in tqdm(original_rows.iterrows(), total=len(original_rows)):
        # The carried-over values are the same for every message of this row.
        kept_values = [row[col] for col in COLS_TO_KEEP]

        for _ in range(GENERATION_MULTIPLIER):
            commit_msg_end_pred = generate_start_msg(end_msg=row["commit_msg_start"],
                                                     diff=row["mods"])

            generated_data["commit_msg_end"].append(commit_msg_end_pred)
            for col, value in zip(COLS_TO_KEEP, kept_values):
                generated_data[col].append(value)

    generated_df = pd.DataFrame.from_dict(generated_data)
    generated_df['start_to_end'] = True

    result = pd.concat([original_rows, generated_df], ignore_index=True)
    result.to_csv(config.START_TO_END_ARTIFACT)

    print("Done")
    return result
|
|
|
|
|
def main():
    """Read the end-to-start artifact CSV and run the synthesis on it.

    The first CSV column is used as the index of the loaded frame.
    """
    transform(pd.read_csv(config.END_TO_START_ARTIFACT, index_col=[0]))
|
|
|
|
|
# Run the synthesis only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
|