"""
Learn to classify the manually annotated CDA attributes (frames, 'riferimento', orientation)
"""
import json

import gensim
import numpy as np
import pandas as pd
import spacy
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import precision_recall_fscore_support
from sklearn.svm import SVC

GLOVE_MODEL = "/net/aistaff/gminnema/thesis_data/data/glove-it/glove_WIKI"
def train(attrib):
    assert attrib in ["cda_frame", "riferimento", "orientation"]

    # load data
    print("Loading data...")
    x_train, y_train, x_dev, y_dev = load_data(attrib)
    print(f"\t\ttrain size: {len(x_train)}")
    print(f"\t\tdev size: {len(x_dev)}")

    # try different setups
    print("Running training setups...")
    scores = []
    setups = [
        # defaults: remove_punct=True, lowercase=True, lemmatize=False, remove_stop=False
        # ({}, {}, SVC(kernel='linear')),
        # ({"lemmatize": True, "remove_stop": True}, {}, SVC(kernel='linear')),
        # ({"lemmatize": True, "remove_stop": True}, {"min_freq": 5}, SVC(kernel='linear')),
        # ({"lemmatize": True, "remove_stop": True}, {"min_freq": 5, "max_freq": .70}, SVC(kernel='linear')),
        # ({"lemmatize": True, "remove_stop": True}, {}, SVC(kernel='linear', C=0.6)),
        # ({"lemmatize": True, "remove_stop": True}, {}, SVC(kernel='linear', C=0.7)),
        # ({"lemmatize": True, "remove_stop": True}, {}, SVC(kernel='linear', C=0.8)),
        ({"lemmatize": True, "remove_stop": True}, {"embed": "glove"}, SVC(kernel='linear', C=0.8)),
        # ({"lemmatize": True, "remove_stop": True}, {}, SVC(kernel="rbf")),
    ]
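    # Each setup is a triple: (options passed to tokenize(), options passed to
    # extract_features(), classifier instance). The commented-out entries above
    # record earlier experiments.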
    nlp = spacy.load("it_core_news_md")

    for s_idx, (text_options, vect_options, model) in enumerate(setups):
        print(f"\tSetup #{s_idx}")

        # extract features
        print("\t\tExtracting features...")
        x_train_fts, vectorizer = extract_features(x_train, nlp, text_options, **vect_options)
        x_dev_fts, _ = extract_features(x_dev, nlp, text_options, **vect_options, vectorizer=vectorizer)
print(f"\t\t\tnum features: {len(vectorizer.vocabulary_)}")
print("\t\tTraining the model...")
model.fit(x_train_fts, y_train)
# evaluate on dev
print("\t\tValidating the model...")
y_dev_pred = model.predict(x_dev_fts)
p_micro, r_micro, f_micro, _ = precision_recall_fscore_support(
y_dev, y_dev_pred, average="micro")
p_classes, r_classes, f_classes, _ = precision_recall_fscore_support(
y_dev, y_dev_pred, average=None, labels=model.classes_, zero_division=0)
print(
f"\t\t\tOverall scores (micro-averaged):\tP={p_micro}\tR={r_micro}\tF={f_micro}"
)
scores.append({
"micro": {
"p": p_micro,
"r": r_micro,
"f": f_micro
},
"classes": {
"p": list(zip(model.classes_, p_classes)),
"r": list(zip(model.classes_, r_classes)),
"f": list(zip(model.classes_, f_classes)),
}
})
prediction_df = pd.DataFrame(zip(x_dev, y_dev, y_dev_pred), columns=["headline", "gold", "prediction"])
prediction_df.to_csv(f"output/migration/cda_classify/predictions_{s_idx:02}.csv")
with open("output/migration/cda_classify/scores.json", "w", encoding="utf-8") as f_scores:
json.dump(scores, f_scores, indent=4)
def load_data(attrib):
    train_data = pd.read_csv(
        "output/migration/preprocess/annotations_train.csv")
    dev_data = pd.read_csv("output/migration/preprocess/annotations_dev.csv")
    x_train = train_data["Titolo"]
    x_dev = dev_data["Titolo"]
    if attrib == "cda_frame":
        y_train = train_data["frame"]
        y_dev = dev_data["frame"]
    elif attrib == "riferimento":
        y_train = train_data["riferimento"]
        y_dev = dev_data["riferimento"]
    else:
x_train = train_data["orientation"]
y_dev = dev_data["orientation"]
return x_train, y_train, x_dev, y_dev
def extract_features(headlines, nlp, text_options, min_freq=1, max_freq=1.0, embed=None, vectorizer=None):
    tokenized = [" ".join(sent) for sent in tokenize(headlines, nlp, **text_options)]
    if embed == "glove":
        # Sketch of the unfinished GloVe branch: assumes GLOVE_MODEL is a plain-text
        # vector file that gensim can read; each headline is represented as the mean
        # of its token vectors (zero vector if no token is in the vocabulary).
        if vectorizer is None:
            vectorizer = gensim.models.KeyedVectors.load_word2vec_format(GLOVE_MODEL, binary=False)
        vectorized = np.array([
            np.mean(
                [vectorizer[t] for t in sent.split() if t in vectorizer]
                or [np.zeros(vectorizer.vector_size)],
                axis=0)
            for sent in tokenized
        ])
    elif vectorizer is None:
        vectorizer = CountVectorizer(lowercase=False, analyzer="word", min_df=min_freq, max_df=max_freq)
        vectorized = vectorizer.fit_transform(tokenized)
    else:
        vectorized = vectorizer.transform(tokenized)
    return vectorized, vectorizer
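# Example usage (mirrors train()): fit the vectorizer on the training split,
# then reuse it for the dev split so both share the same feature space:
#   x_train_fts, vec = extract_features(x_train, nlp, {"lemmatize": True})
#   x_dev_fts, _ = extract_features(x_dev, nlp, {"lemmatize": True}, vectorizer=vec)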
def tokenize(headlines, nlp, remove_punct=True, lowercase=True, lemmatize=False, remove_stop=False):
    for sent in headlines:
        doc = nlp(sent)
        tokens = (
            t.lemma_ if lemmatize else t.text
            for t in doc
            if (not remove_stop or not t.is_stop) and (not remove_punct or t.pos_ not in ["PUNCT", "SYM", "X"])
        )
        if lowercase:
            tokens = [t.lower() for t in tokens]
        else:
            tokens = list(tokens)
        yield tokens
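# Hypothetical example: with the defaults (punctuation removed, lowercased),
#   list(tokenize(["Sbarchi a Lampedusa: nuova emergenza"], nlp))
# should yield roughly [["sbarchi", "a", "lampedusa", "nuova", "emergenza"]].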
if __name__ == '__main__':
    train(attrib="cda_frame")