# bigbio_test/ngram.py
# partially from https://gist.github.com/gaulinmp/da5825de975ed0ea6a24186434c24fe4
from nltk.util import ngrams
import spacy
import re
from itertools import chain
from collections import Counter
from datasets import load_dataset
# spaCy pipeline for tokenization; the rule-based sentencizer supplies
# sentence boundaries, and the default stop-word list is reused below.
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("sentencizer")
STOPWORDS = nlp.Defaults.stop_words
N = 5  # n-gram order used for the overlap counts
# Naive sentence splitting: replace everything except letters, periods, and
# newlines with spaces, then split on periods or newlines.
re_sent_ends_naive = re.compile(r'[.\n]')
re_stripper_naive = re.compile(r'[^a-zA-Z.\n]')
splitter_naive = lambda x: re_sent_ends_naive.split(re_stripper_naive.sub(' ', x))
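# A quick illustration of the naive splitter (hedged example, not from the
# original file): digits and punctuation other than '.'/newline become spaces,
# then the text is split at sentence ends:
# splitter_naive("Dr. Smith et al. report results.\nSee table 2.")
# -> ['Dr', ' Smith et al', ' report results', '', 'See table  ', '']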
def remove_stop_words(text):
    """Drop spaCy stop words from a list of tokens for one sentence."""
    return [w for w in text if w not in STOPWORDS]
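# Hedged usage sketch: membership is tested against spaCy's default stop list,
# so e.g. remove_stop_words(["the", "quick", "brown", "fox"])
# -> ["quick", "brown", "fox"]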
# Split a text into sentences with spaCy, then drop stop words per sentence.
def parse_sentences(text, nlp):
    doc = nlp(text)
    # Compare token strings (not spaCy Token objects) against STOPWORDS.
    sentences = (remove_stop_words([t.text for t in sent]) for sent in doc.sents)
    return sentences
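# Hedged usage sketch (assumes the en_core_web_sm pipeline loaded above; the
# stop list is lowercase, so capitalized stop words would survive the filter):
# list(parse_sentences("the cat sat. the dog ran.", nlp))
# -> [['cat', 'sat', '.'], ['dog', 'ran', '.']]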
def get_tuples_manual_sentences(txt, N):
    """Naively split `txt` into sentences on periods/newlines and return
    the stop-word-filtered sentences plus their N-gram tuples."""
    if not txt:
        return None, []
    sentences = (x.split() for x in splitter_naive(txt) if x)
    sentences = [remove_stop_words(s) for s in sentences]
    # spaCy-based alternative: sentences = parse_sentences(txt, nlp)
    ng = (ngrams(x, N) for x in sentences if len(x) >= N)
    return sentences, list(chain(*ng))
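# Hedged usage sketch with a small order for readability (the script itself
# uses N = 5); stop words are dropped before the N-grams are formed:
# get_tuples_manual_sentences("the cat sat on the mat", 2)
# -> ([['cat', 'sat', 'mat']], [('cat', 'sat'), ('sat', 'mat')])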
def count_by_split(split_data):
    """Count N-gram occurrences over every 'text' entry in a dataset split."""
    c = Counter()
    for entry in split_data:
        text = entry['text']
        sents, tup = get_tuples_manual_sentences(text, N)
        # Join each N-gram into a single underscore-delimited counter key.
        tup = ["_".join(ta) for ta in tup]
        c.update(tup)
    return c
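# Hedged usage sketch: any iterable of dicts with a 'text' key works, e.g.
# count_by_split([{'text': 'alpha beta gamma delta epsilon zeta'}])
# -> Counter({'alpha_beta_gamma_delta_epsilon': 1,
#             'beta_gamma_delta_epsilon_zeta': 1})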
if __name__ == "__main__":
    # Assumes a local checkout of the bigbio repo so the chemdner loading
    # script is available at this relative path.
    data = load_dataset("bigbio/biodatasets/chemdner/chemdner.py", name="chemdner_bigbio_text")
    counters = []
    for split, split_data in data.items():
        counters.append(count_by_split(split_data))
    # N-grams shared between the first two splits (e.g. train/validation).
    ab_intersect = counters[0] & counters[1]
    # N-grams seen more than twice in the first split but never in the second.
    diff = {x: count for x, count in counters[0].items()
            if x not in ab_intersect and count > 2}
    if len(counters) > 2:
        # Overlap between the second and third splits, kept for inspection.
        bc_intersect = counters[1] & counters[2]
    print(ab_intersect.most_common(10))
    print(Counter(diff).most_common(10))
    data.cleanup_cache_files()
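# A hedged follow-up sketch (not part of the original script): one way to turn
# the shared-counter intersection above into a single leakage number. The
# function name and the train/other framing are assumptions for illustration.
def overlap_fraction(train_counter, other_counter):
    """Fraction of distinct train N-grams that also occur in another split."""
    if not train_counter:
        return 0.0
    return len(train_counter & other_counter) / len(train_counter)
# e.g. overlap_fraction(counters[0], counters[1]) after the loop above.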