File size: 3,537 Bytes
b11ac48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import os
import sys

import pandas as pd
import requests


# Base URL of the SocioFillmore API server, and the key sent as the
# ``auth_key`` query parameter on data requests. Both can be overridden through
# environment variables so the credential need not be kept in source control;
# the previous hard-coded values remain the defaults for backward compatibility.
# NOTE(review): consider dropping the hard-coded key default and rotating it.
SOCIOFILLMORE_API = os.environ.get("SOCIOFILLMORE_API", "http://127.0.0.1:5000")
AUTH_KEY = os.environ.get("SOCIOFILLMORE_AUTH_KEY", "3TrJ397oh#^")


def get_sample(s, dataset, n_samples, frame, construction, role, dependency):
    """Fetch sample sentences in which ``role`` of ``frame`` has a given dependency.

    Switches the API to ``dataset``, then queries ``/sample_frame`` (model
    ``lome_0shot``) and turns every frame structure of type ``frame`` that
    contains ``role`` into one flat row.

    Args:
        s: HTTP session (``requests.Session``) used for both API calls.
        dataset: dataset identifier sent to ``/switch_dataset``.
        n_samples: value sent as the ``n`` query parameter.
        frame: frame name to request and to match in the response.
        construction: construction filter forwarded to the API.
        role: role name whose span is extracted.
        dependency: dependency label forwarded to the API and copied into rows.

    Returns:
        A list of dicts with keys ``dataset``, ``sentence``, ``frame``,
        ``target``, ``role_label``, ``role_span`` and ``dependency``. One
        sentence may yield several rows (one per matching frame structure).

    Raises:
        requests.HTTPError: if either API call returns an error status.
    """
    # The dataset switch is a separate call whose effect the next query relies
    # on (presumably server/session state); if it fails we must not silently
    # sample from whatever dataset happens to be active.
    switch_resp = s.get(SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset})
    switch_resp.raise_for_status()

    sample_resp = s.get(
        SOCIOFILLMORE_API + "/sample_frame",
        params={
            "auth_key": AUTH_KEY,
            "frame": frame,
            "construction": construction,
            "role": role,
            "dependency": dependency,
            "model": "lome_0shot",
            "n": n_samples,
        },
    )
    sample_resp.raise_for_status()
    data = sample_resp.json()

    rows_out = []
    for sent in data:
        for fns in sent["fn_structures"]:
            if fns["frame"] != frame:
                continue

            # Role entries are indexed as [name, span_info]; take the first one
            # whose name matches (assumed pair layout — verify against the API).
            target_role = next((r for r in fns["roles"] if r[0] == role), None)
            if target_role is None:
                continue

            rows_out.append(
                {
                    "dataset": dataset,
                    "sentence": " ".join(sent["sentence"]),
                    "frame": frame,
                    "target": " ".join(fns["target"]["tokens_str"]),
                    "role_label": role,
                    "role_span": " ".join(target_role[1]["tokens_str"]),
                    "dependency": dependency,
                }
            )

    return rows_out


def get_labels(s, dataset, frame):
    """Return the set of labels found in the frame-frequency counts for ``frame``.

    Switches the API to ``dataset``, queries ``/frame_freq`` (model
    ``lome_0shot``) and collects the third ``"::"``-separated component of
    every x-axis entry in ``relevant_frame_counts``. The callers use these as
    dependency labels.

    Args:
        s: HTTP session (``requests.Session``) used for both API calls.
        dataset: dataset identifier sent to ``/switch_dataset``.
        frame: frame name sent as the ``frames`` query parameter.

    Returns:
        A set of label strings.

    Raises:
        requests.HTTPError: if either API call returns an error status.
    """
    # Fail loudly if the dataset switch fails instead of querying the wrong data.
    switch_resp = s.get(SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset})
    switch_resp.raise_for_status()

    freq_resp = s.get(
        SOCIOFILLMORE_API + "/frame_freq",
        params={
            "auth_key": AUTH_KEY,
            "model": "lome_0shot",
            "frames": frame,
            "constructions": "",
            "group_by_cat": "n",
            "group_by_constr": "n",
            "group_by_role_expr": 2,
            "relative": "y",
            "plot_over_days_post": "n",
        },
    )
    freq_resp.raise_for_status()
    data = freq_resp.json()

    # x-axis entries look like "<a>::<b>::<label>..."; the label is assumed to
    # be the third component (TODO confirm the exact key format with the API).
    return {key.split("::")[2] for key in data["relevant_frame_counts"]["x"]}


def main(language):
    """Collect frame-role samples per dependency label and write them to CSV.

    For the selected language, fetch the labels for the configured frame, then
    for every label except ``_UNK_DEP`` request 2 samples for each of the two
    configured roles. All rows go to
    ``output/common/query_frame_samples/<language>_dep_samples.csv``; the
    directory is created if it does not exist.

    Args:
        language: ``"it"`` (dataset femicides/rai, frame Killing, roles
            Killer/Victim) or ``"nl"`` (dataset crashes/thecrashes, frame
            Cause_harm, roles Agent/Victim).

    Raises:
        ValueError: if ``language`` is not a supported code.
    """
    # language code -> (dataset, frame, roles sampled for every label, in order)
    configs = {
        "it": ("femicides/rai", "Killing", ("Killer", "Victim")),
        "nl": ("crashes/thecrashes", "Cause_harm", ("Agent", "Victim")),
    }
    if language not in configs:
        raise ValueError(
            f"Unsupported language {language!r}; expected one of {sorted(configs)}"
        )
    dataset, frame, roles = configs[language]
    tag = language.upper()

    with requests.Session() as s:
        print(f"Finding {tag} labels...")
        labels = get_labels(s, dataset, frame)

        sample_rows = []
        for label in sorted(labels):
            # Skip the _UNK_DEP label (presumably the catch-all for unresolved
            # dependencies — confirm).
            if label == "_UNK_DEP":
                continue

            print(f"Label ({tag}): {label}")
            for role in roles:
                sample_rows.extend(get_sample(s, dataset, 2, frame, "*", role, label))

    out_dir = "output/common/query_frame_samples"
    # to_csv does not create missing directories, so make sure it exists.
    os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame(sample_rows).to_csv(f"{out_dir}/{language}_dep_samples.csv")


if __name__ == "__main__":
    # The first positional argument is the language code ("it" or "nl");
    # exit with a usage message instead of an IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} {{it|nl}}")
    main(language=sys.argv[1])