|
from grazie.api.client.chat.prompt import ChatPrompt |
|
from grazie.api.client.endpoints import GrazieApiGatewayUrls |
|
from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType |
|
from grazie.api.client.profiles import LLMProfile |
|
from tqdm import tqdm |
|
|
|
import config |
|
import hf_data_loader |
|
|
|
# Module-level Grazie API gateway client shared by every LLM request in this
# script. It is created at import time, talks to the STAGING gateway and
# authenticates as a service with the JWT token read from ``config``.
client = GrazieApiGatewayClient(
    grazie_agent=GrazieAgent(
        name="commit-rewriting-summary-generation",
        version="dev",
    ),
    url=GrazieApiGatewayUrls.STAGING,
    auth_type=AuthType.SERVICE,
    grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN,
)
|
|
|
|
|
def build_prompt(reference, diff):
    """Build the prompt that asks the LLM to reconstruct an initial commit message.

    The prompt presents ``reference`` as a human-edited version of an
    LLM-generated commit message for the changes in ``diff``. It asks the
    model to print the initial (pre-edit) message after the token "OUTPUT".
    The prompt text itself ends with that token.

    Args:
        reference: The edited (final) commit message text.
        diff: The source code changes the commit message describes. It is
            interpolated with an f-string, so non-string values are rendered
            via ``str()``.

    Returns:
        The complete prompt string, with ``diff`` placed before ``reference``.
    """
    # The literal is kept flush-left on purpose: any indentation inside the
    # triple-quoted string would be sent to the model verbatim.
    return f"""A software developer uses an LLM to generate commit messages.

They generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES

After generating the commit message, the developer understands that it is not perfect. After making some changes,
they come up with an edited version of the message. Here is this edited message:
START OF THE COMMIT MESSAGE
{reference}
END OF THE COMMIT MESSAGE

Your task is to print the initial, LLM-generated commit message. Print only the initial commit message's text after the
token "OUTPUT".

OUTPUT"""
|
|
|
|
|
def generate_prompt_for_row(row):
    """Return the reconstruction prompt for one dataset row.

    Intended for ``DataFrame.apply(..., axis=1)``. It takes the edited commit
    message from ``row['reference']`` and the code changes from
    ``row['mods']`` and passes both to ``build_prompt``.

    NOTE(review): ``row['mods']`` is interpolated into the prompt as-is. If
    it is a structured value rather than a plain diff string, it is rendered
    with ``str()``. Confirm against ``hf_data_loader``.
    """
    return build_prompt(row['reference'], row['mods'])
|
|
|
|
|
def generate_initial_msg(prompt):
    """Send ``prompt`` to the LLM and return the text of its reply.

    The request is a two-message chat: a generic system message followed by
    ``prompt`` as the user turn. It is sent through the module-level
    ``client`` using the ``gpt-4-1106-preview`` profile.

    Args:
        prompt: The user message, as produced by ``generate_prompt_for_row``.

    Returns:
        The ``content`` of the chat response.

    NOTE(review): no post-processing is applied. Any echoed "OUTPUT" token or
    surrounding whitespace in the reply ends up in the dataset; confirm
    whether stripping is wanted.
    """
    conversation = (
        ChatPrompt()
        .add_system("You are a helpful assistant.")
        .add_user(prompt)
    )
    response = client.chat(
        chat=conversation,
        profile=LLMProfile("gpt-4-1106-preview"),
    )
    return response.content
|
|
|
|
|
def generate_synthetic_dataset():
    """Generate LLM-reconstructed initial commit messages and save them as CSV.

    Steps:
    1. Load the full commit dataset.
    2. Add an ``initial_msg_prompt`` column with one prompt per row.
    3. Query the LLM once per prompt, sequentially with a tqdm progress bar,
       and store the replies in ``initial_msg_pred``.
    4. Write the whole frame to ``config.SYNTHETIC_DATASET_ARTIFACT``.
    """
    df = hf_data_loader.load_full_commit_dataset_as_pandas()
    df['initial_msg_prompt'] = df.apply(generate_prompt_for_row, axis=1)

    # One blocking API call per row. Nothing is written to disk until every
    # row has succeeded, so an error mid-way loses all generated messages.
    df['initial_msg_pred'] = [
        generate_initial_msg(prompt)
        for prompt in tqdm(df['initial_msg_prompt'])
    ]

    df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
|
|
|
|
|
if __name__ == '__main__':
    # Run the full generation pipeline only when executed as a script.
    generate_synthetic_dataset()
|
|