# Source: sociofillmore_public / sociofillmore/migration/split_lome_predictions.py
# Author: Gosse Minnema
# Commit b11ac48: "Add sociofillmore code, load dataset via private dataset repo"
import argparse
import json
import os
import warnings

import pandas as pd
def main(
    input_json,
    input_txt,
    output_dir,
    meta_csv="output/migration/split_data/split_dev10.texts.meta.csv",
):
    """Split concatenated LOME predictions into one directory per text.

    Items are paired purely by position. These three entries are taken to
    describe the same text:

    - the i-th value of the ``text_id`` column of ``meta_csv``;
    - the i-th element of the JSON list in ``input_json``;
    - the i-th blank-line-separated chunk of ``input_txt``.

    For every text id ``t_id`` this writes two files into
    ``{output_dir}/{t_id}/``:

    - ``lome_{t_id}.comm.json``: a one-element JSON list holding the
      prediction;
    - ``lome_{t_id}.comm.txt``: the text chunk followed by a blank line.

    Args:
        input_json: Path to a JSON file whose top-level value is a list of
            predictions.
        input_txt: Path to a UTF-8 text file whose predictions are separated
            by blank lines.
        output_dir: Root output directory. Per-text subdirectories are
            created as needed, and existing files are overwritten.
        meta_csv: Path to a CSV with a ``text_id`` column. The ids must be in
            the same order as the predictions. Ids are converted with
            ``int()`` for progress reporting.

    Warns:
        UserWarning: If the three inputs have different lengths. Only the
            first ``min(lengths)`` items are written in that case, and the
            positional pairing may be wrong.
    """
    meta_df = pd.read_csv(meta_csv)
    text_ids = meta_df["text_id"].to_list()

    with open(input_json, encoding="utf-8") as f:
        json_predictions = json.load(f)
    with open(input_txt, encoding="utf-8") as f:
        txt_predictions = f.read().split("\n\n")
    # A trailing blank-line separator (the format this function itself writes)
    # produces an empty final chunk, which is not a prediction.
    if txt_predictions and not txt_predictions[-1].strip():
        txt_predictions.pop()

    n_ids, n_json, n_txt = len(text_ids), len(json_predictions), len(txt_predictions)
    if not n_ids == n_json == n_txt:
        # zip() below stops at the shortest input. Surface the mismatch
        # instead of silently truncating, because positional pairing is
        # then suspect.
        warnings.warn(
            f"Length mismatch: {n_ids} text ids, {n_json} JSON predictions, "
            f"{n_txt} text predictions; only the first "
            f"{min(n_ids, n_json, n_txt)} will be written.",
            stacklevel=2,
        )

    for t_id, json_p, txt_p in zip(text_ids, json_predictions, txt_predictions):
        if int(t_id) % 100 == 0:
            print(t_id)  # coarse progress indicator

        prediction_dir = f"{output_dir}/{t_id}"
        # exist_ok avoids the check-then-create race and allows re-runs.
        os.makedirs(prediction_dir, exist_ok=True)

        prediction_file_json = f"{prediction_dir}/lome_{t_id}.comm.json"
        prediction_file_txt = f"{prediction_dir}/lome_{t_id}.comm.txt"
        with open(prediction_file_json, "w", encoding="utf-8") as f_out:
            json.dump([json_p], f_out)
        with open(prediction_file_txt, "w", encoding="utf-8") as f_out:
            f_out.write(txt_p + "\n\n")
if __name__ == "__main__":
    # Command-line entry point. The defaults reproduce the most recent run of
    # this script, so invoking it with no arguments behaves as before.
    # Earlier runs used these input -> output pairs (inputs as .json + .txt):
    #   output/migration/lome/lome_0shot/lome_lome_0shot_migration_all_tc.comm.{json,txt}
    #     -> output/migration/lome/multilabel/lome_0shot/pavia
    #   output/migration/lome/lome_0shot/lome_lome_0shot_migration_all_best-truecase.comm.{json,txt}
    #     -> output/migration/lome/multilabel/lome_0shot/pavia
    #   output/migration/lome/lome_zs-tgt_ev-frm/data-in.concat.combined_zs_ev.tc_bilstm.{json,txt}
    #     -> output/migration/lome/multilabel/lome_zs-tgt_ev_frm/pavia
    # NOTE(review): the default input paths are absolute paths on one
    # developer's machine; pass --input-json/--input-txt on other machines.
    parser = argparse.ArgumentParser(
        description="Split concatenated LOME predictions into one directory per text."
    )
    parser.add_argument(
        "--input-json",
        default="/home/gossminn/WorkSyncs/Code/fn-for-social-frames/output/migration/lome/lome_migration_concat.comm.json",
        help="JSON file containing a list of predictions",
    )
    parser.add_argument(
        "--input-txt",
        default="/home/gossminn/WorkSyncs/Code/fn-for-social-frames/output/migration/lome/lome_migration_concat.comm.txt",
        help="text file with blank-line-separated predictions",
    )
    parser.add_argument(
        "--output-dir",
        default="output/migration/lome/multilabel/lome_0shot/pavia",
        help="root directory for the per-text output directories",
    )
    args = parser.parse_args()
    main(
        input_json=args.input_json,
        input_txt=args.input_txt,
        output_dir=args.output_dir,
    )