"""Sample frame instances per dependency label from the SocioFillmore API.

For the chosen language, the script asks the API which dependency labels
occur for a frame, requests a few example sentences per label and role, and
writes the collected rows to a CSV file.

Usage: pass the language code ("it" or "nl") as the first CLI argument.
"""

import json
import sys
from pathlib import Path
from typing import NamedTuple

import pandas as pd
import requests

SOCIOFILLMORE_API = "http://127.0.0.1:5000"
# NOTE(review): credential is hard-coded in source; consider loading it from
# the environment or a config file kept out of version control.
AUTH_KEY = "3TrJ397oh#^"

MODEL = "lome_0shot"
# Label returned by the API that main() excludes from sampling.
UNKNOWN_DEPENDENCY_LABEL = "_UNK_DEP"
SAMPLES_PER_LABEL = 2
OUTPUT_DIR = Path("output/common/query_frame_samples")


class _SampleJob(NamedTuple):
    """What to sample for one language: dataset, frame and the roles of interest."""

    dataset: str
    frame: str
    roles: tuple


LANGUAGE_JOBS = {
    "it": _SampleJob("femicides/rai", "Killing", ("Killer", "Victim")),
    "nl": _SampleJob("crashes/thecrashes", "Cause_harm", ("Agent", "Victim")),
}


def _switch_dataset(s, dataset):
    """Tell the API which dataset subsequent queries on session *s* refer to.

    Raises:
        requests.HTTPError: if the API answers with a 4xx/5xx status.
    """
    response = s.get(SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset})
    # Without this check a failed switch would go unnoticed and the next query
    # would silently run against whichever dataset was active before.
    response.raise_for_status()


def _get_json(s, endpoint, params):
    """GET *endpoint* with *params* on session *s* and return the decoded JSON body.

    Raises:
        requests.HTTPError: if the API answers with a 4xx/5xx status.
    """
    response = s.get(SOCIOFILLMORE_API + endpoint, params=params)
    response.raise_for_status()
    return json.loads(response.text)


def get_sample(s, dataset, n_samples, frame, construction, role, dependency):
    """Fetch sample sentences for a frame/role/dependency combination.

    Args:
        s: a ``requests.Session`` (or compatible object) used for the HTTP calls.
        dataset: dataset name passed to ``/switch_dataset``.
        n_samples: number of samples requested from ``/sample_frame``.
        frame: frame name to query and to filter the returned structures on.
        construction: construction filter passed through to the API.
        role: role label to look for within each matching frame structure.
        dependency: dependency label passed through to the API.

    Returns:
        A list of dicts, one per returned frame structure that evokes *frame*
        and has a *role* span, with the keys ``dataset``, ``sentence``,
        ``frame``, ``target``, ``role_label``, ``role_span`` and ``dependency``.

    Raises:
        requests.HTTPError: if either API call answers with a 4xx/5xx status.
    """
    _switch_dataset(s, dataset)
    data = _get_json(
        s,
        "/sample_frame",
        {
            "auth_key": AUTH_KEY,
            "frame": frame,
            "construction": construction,
            "role": role,
            "dependency": dependency,
            "model": MODEL,
            "n": n_samples,
        },
    )

    rows_out = []
    for sent in data:
        for fns in sent["fn_structures"]:
            if fns["frame"] != frame:
                continue
            # Roles are (label, span) pairs; only the first match is used.
            target_role = next((r for r in fns["roles"] if r[0] == role), None)
            if target_role is None:
                continue
            rows_out.append(
                {
                    "dataset": dataset,
                    "sentence": " ".join(sent["sentence"]),
                    "frame": frame,
                    "target": " ".join(fns["target"]["tokens_str"]),
                    "role_label": role,
                    "role_span": " ".join(target_role[1]["tokens_str"]),
                    "dependency": dependency,
                }
            )
    return rows_out


def get_labels(s, dataset, frame):
    """Return the set of labels the API reports for *frame* in *dataset*.

    The labels are taken from the third ``::``-separated field of each entry
    in ``relevant_frame_counts.x`` of the ``/frame_freq`` response.

    Raises:
        requests.HTTPError: if either API call answers with a 4xx/5xx status.
    """
    _switch_dataset(s, dataset)
    data = _get_json(
        s,
        "/frame_freq",
        {
            "auth_key": AUTH_KEY,
            "model": MODEL,
            "frames": frame,
            "constructions": "",
            "group_by_cat": "n",
            "group_by_constr": "n",
            "group_by_role_expr": 2,
            "relative": "y",
            "plot_over_days_post": "n",
        },
    )
    return {entry.split("::")[2] for entry in data["relevant_frame_counts"]["x"]}


def main(language):
    """Collect dependency-label samples for *language* and write them to CSV.

    Args:
        language: a key of ``LANGUAGE_JOBS`` ("it" or "nl").

    Raises:
        ValueError: if *language* is not supported.
        requests.HTTPError: if an API call answers with a 4xx/5xx status.
    """
    job = LANGUAGE_JOBS.get(language)
    if job is None:
        supported = ", ".join(sorted(LANGUAGE_JOBS))
        raise ValueError(f"Unsupported language {language!r}; expected one of: {supported}")

    tag = language.upper()
    rows = []
    # The context manager closes the session's pooled connections on exit.
    with requests.Session() as s:
        print(f"Finding {tag} labels...")
        labels = get_labels(s, job.dataset, job.frame)
        for label in sorted(labels):
            if label == UNKNOWN_DEPENDENCY_LABEL:
                continue
            print(f"Label ({tag}): {label}")
            for role in job.roles:
                rows.extend(
                    get_sample(s, job.dataset, SAMPLES_PER_LABEL, job.frame, "*", role, label)
                )

    output_path = OUTPUT_DIR / f"{language}_dep_samples.csv"
    # to_csv does not create missing parent directories.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(rows).to_csv(output_path)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        choices = "|".join(sorted(LANGUAGE_JOBS))
        sys.exit(f"usage: {sys.argv[0]} {{{choices}}}")
    main(language=sys.argv[1])