|
import json |
|
|
|
from tqdm import tqdm |
|
|
|
import config |
|
from api_wrappers import hf_data_loader |
|
from generation_steps import synthetic_start_to_end |
|
|
|
|
|
def transform(df):
    """Build the dataset that is handed over for labeling and save it as CSV.

    Each commit in ``df`` is matched by its ``(hash, repo)`` pair against the
    manual rewriting dataset returned by
    ``hf_data_loader.load_raw_rewriting_as_pandas()``:

    * matched rows take ``prediction`` and ``enhanced`` from the manual
      ``commit_msg_start`` and ``commit_msg_end`` columns;
    * all other rows keep their ``prediction`` and get ``enhanced`` from
      ``synthetic_start_to_end.generate_end_msg``.

    Finally the ``mods`` column is JSON-encoded and the frame is written to
    ``config.DATA_FOR_LABELING_ARTIFACT``.

    Args:
        df: DataFrame with at least the columns ``hash``, ``repo``,
            ``prediction`` and ``mods``. It is left untouched; all changes
            are made on a copy.

    Returns:
        The new DataFrame, i.e. the copy of ``df`` with the added
        ``manual_sample`` and ``enhanced`` columns and the rewritten
        ``prediction`` and ``mods`` columns.
    """
    print("Generating data for labeling:")
    synthetic_start_to_end.print_config()
    tqdm.pandas()  # makes the progress_apply methods used below available

    manual_df = hf_data_loader.load_raw_rewriting_as_pandas()

    # Shuffle (seeded by config.RANDOM_STATE) before de-duplicating, so that
    # for a commit that appears several times the surviving row is a random
    # pick instead of always the first one in dataset order.
    manual_df = (
        manual_df.sample(frac=1, random_state=config.RANDOM_STATE)
        .set_index(["hash", "repo"])[["commit_msg_start", "commit_msg_end"]]
    )
    manual_df = manual_df[~manual_df.index.duplicated(keep="first")]

    def commit_key(row):
        # Key layout must match the (hash, repo) index of manual_df.
        return (row["hash"], row["repo"])

    def get_is_manually_rewritten(row):
        return commit_key(row) in manual_df.index

    def get_prediction_message(row):
        if row["manual_sample"]:
            return manual_df.loc[commit_key(row)]["commit_msg_start"]
        return row["prediction"]

    def get_enhanced_message(row):
        if row["manual_sample"]:
            return manual_df.loc[commit_key(row)]["commit_msg_end"]
        return synthetic_start_to_end.generate_end_msg(
            start_msg=row["prediction"], diff=row["mods"]
        )

    # Work on a copy: the caller's frame must not silently gain columns, and
    # a repeated call must not JSON-encode an already encoded 'mods' column.
    result = df.copy()

    result["manual_sample"] = result.progress_apply(
        get_is_manually_rewritten, axis=1
    )
    result["enhanced"] = result.progress_apply(get_enhanced_message, axis=1)
    result["prediction"] = result.progress_apply(get_prediction_message, axis=1)

    # Encode 'mods' only after 'enhanced' has been computed, because
    # generate_end_msg is given the raw row["mods"] value as its diff.
    result["mods"] = result["mods"].progress_apply(json.dumps)

    result.to_csv(config.DATA_FOR_LABELING_ARTIFACT)
    print("Done")
    return result
|
|
|
|
|
def main():
    """Script entry point: load the commits with predictions and transform them.

    Sets ``synthetic_start_to_end.GENERATION_ATTEMPTS`` to 3 before running
    (presumably the number of attempts per generated message; confirm in
    ``synthetic_start_to_end``), then passes the frame loaded by
    ``hf_data_loader.load_full_commit_with_predictions_as_pandas()`` to
    ``transform``. The value returned by ``transform`` is not used here.
    """
    # Must be set before transform() runs, since it is a module-level setting.
    synthetic_start_to_end.GENERATION_ATTEMPTS = 3

    commits = hf_data_loader.load_full_commit_with_predictions_as_pandas()
    transform(commits)


if __name__ == "__main__":
    main()
|
|