|
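# nCloze: a Streamlit app that turns a passage into a multiple-choice cloze test.
# A masked language model scores candidate fill-ins for each blanked word, and a
# simulated-annealing step picks distractors that are plausible yet distinct from
# the answer and from each other.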
import streamlit as st |
|
|
|
import numpy as np |
|
from numpy import ndarray |
|
import pandas as pd |
|
import torch as T |
|
from torch import Tensor, device |
|
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig, AutoModel |
|
from nltk.corpus import stopwords |
|
from nltk.stem.porter import PorterStemmer

import json
import re  # used below to build the subword/suffix masks from the tokenizer vocabulary
|
import nltk |
|
from nltk import FreqDist |
|
from nltk.corpus import gutenberg |
|
import urllib.request |
|
from string import punctuation |
|
from math import log,exp,sqrt |
|
import random |
|
from time import sleep |
|
|
|
nltk.download('stopwords') |
|
nltk.download('gutenberg') |
|
|
|
cos = T.nn.CosineSimilarity(dim=0) |
|
|
|
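# Word lists and a profanity list (pinned to a specific commit) used to filter
# distractor candidates; the sleeps space out the consecutive downloads.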
urllib.request.urlretrieve("https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/dict-info.txt", 'dict-info.txt') |
|
sleep(1) |
|
urllib.request.urlretrieve("https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/dict-unix.txt", 'dict-unix.txt') |
|
sleep(1) |
|
urllib.request.urlretrieve("https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/profanity.json", 'profanity.json') |
|
|
|
|
|
|
|
|
|
|
|
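# Tunable constants:
#   CONTEXTUAL_EMBEDDING_LAYERS - hidden layers summed when embedding candidates in context
#   EXTEND_SUBWORDS / MAX_SUBWORDS - whether and how far to grow candidates from subword pieces
#   DEBUG_OUTPUT - print intermediate candidates and scores
#   DISTRACTORS_FROM_TEXT - only accept distractors that already appear in the passage
#   MIN_SENT_WORDS - minimum number of words required in the masked-sentence context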
CONTEXTUAL_EMBEDDING_LAYERS = [12] |
|
EXTEND_SUBWORDS=True |
|
MAX_SUBWORDS=1 |
|
DEBUG_OUTPUT=True |
|
DISTRACTORS_FROM_TEXT=False |
|
MIN_SENT_WORDS = 7 |
|
|
|
|
|
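# Porter stemmer (used to reject candidates that share a stem with the answer) and a
# word-frequency distribution over the Gutenberg corpus.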
stemmer = PorterStemmer() |
|
freq = FreqDist(i.lower() for i in gutenberg.words()) |
|
print(freq.most_common()[:5]) |
|
|
|
with open('dict-unix.txt') as f:
    words_unix = set(line.strip() for line in f)
with open('dict-info.txt') as f:
    words_info = set(line.strip() for line in f)
words_small = words_unix.intersection(words_info)
words_large = words_unix.union(words_info)

with open('profanity.json') as f:
    profanity = json.load(f)
|
|
|
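# Sentence segmentation: a Stanza pipeline is built, but NLTK's Punkt tokenizer is what
# splits the passage into sentences in candidates() below.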
import stanza |
|
|
|
nlp = stanza.Pipeline(lang='en', processors='tokenize') |
|
|
|
nltk.download('punkt') |
|
nltk_sent_toker = nltk.data.load('tokenizers/punkt/english.pickle') |
|
|
|
def is_word(w):

    '''Check whether a word (including contractions and hyphenations) exists in the dictionaries'''

    splt = w.lower().split("'")

    if len(splt) > 2:

        return False

    elif len(splt) == 2:

        return is_word(splt[0]) and (splt[1] in ['t', 'nt', 's', 'll'])

    elif '-' in w:

        return all(is_word(part) for part in w.split('-'))

    else:

        return w.lower() in words_unix or w.lower() in words_info
|
|
|
def get_emb(snt_toks, tgt_toks, layers=None): |
|
'''Embeds a group of subword tokens in place of a mask, using the entire |
|
sentence for context. Returns the average of the target token embeddings, |
|
which are summed over the hidden layers. |
|
|
|
snt_toks: the tokenized sentence, including the mask token |
|
tgt_toks: the tokens (subwords) to replace the mask token |
|
layers (optional): which hidden layers to sum (list of indices)''' |
|
mask_idx = snt_toks.index(toker.mask_token_id) |
|
snt_toks = snt_toks.copy() |
|
|
|
while mask_idx + len(tgt_toks)-1 >= 512: |
|
|
|
snt_toks = snt_toks[100:] |
|
mask_idx -= 100 |
|
|
|
snt_toks[mask_idx:mask_idx+1] = tgt_toks |
|
snt_toks = snt_toks[:512] |
|
    with T.no_grad():
        t = T.tensor([snt_toks])
        m = T.tensor([[1]*len(snt_toks)])
        if T.cuda.is_available():
            # Tensor.cuda() returns a copy; assign it so the inputs live on the model's device
            t = t.cuda()
            m = m.cuda()
        output = model(t, m, output_hidden_states=True)
|
layers = CONTEXTUAL_EMBEDDING_LAYERS if layers is None else layers |
|
output = T.stack([output.hidden_states[i] for i in layers]).sum(0).squeeze() |
|
|
|
return output[mask_idx:mask_idx+len(tgt_toks)].mean(dim=0) |
|
|
|
def energy(ctx, scaled_dists, scaled_sims, choices, words, ans): |
|
|
|
|
|
    '''Cost terms used to choose distractors: e_emb is the harmonic mean of pairwise
    embedding distances among the choices and the answer, e_ctx is the summed scaled
    context score, and e_sim is the harmonic mean of similarities to the answer
    (computed but not used in the annealing objective).'''
|
|
|
|
|
|
|
hm_sim = 0 |
|
e_ctx = 0 |
|
for i in choices: |
|
hm_sim += 1./scaled_sims[i] |
|
e_ctx += ctx[i] |
|
|
|
e_sim = float(len(choices))/hm_sim |
|
|
|
hm_emb = 0 |
|
count = 0 |
|
c = choices + [len(ctx)] |
|
for i in range(len(c)): |
|
for j in range(i): |
|
d = scaled_dists['%s-%s'%(max(c[i],c[j]), min(c[i], c[j]))] |
|
|
|
hm_emb += 1./d |
|
count += 1 |
|
e_emb = float(count)/hm_emb |
|
return float(e_emb), e_ctx, float(e_sim) |
|
|
|
def anneal(probs_sent_context, probs_para_context, embs, emb_ans, words, k, ans): |
|
    '''Find k distractor indices that jointly score high in context and are distant from
    each other and from the answer in embedding space.'''
|
|
|
m = len(probs_sent_context) |
|
|
|
its = 1000 |
|
n = len(probs_para_context) |
|
choices = list(range(k)) |
|
|
|
dists = {} |
|
embsa = embs + [emb_ans] |
|
for i in range(len(embsa)): |
|
for j in range(i): |
|
dists['%s-%s'%(i,j)] = 1-cos(embsa[i], embsa[j]) |
|
|
|
|
|
dist_min = T.min(T.tensor(list(dists.values()))) |
|
dist_max = T.max(T.tensor(list(dists.values()))) |
|
for key, dist in dists.items(): |
|
dists[key] = (dist - dist_min)/(dist_max-dist_min) |
|
|
|
sims = T.tensor([cos(emb_ans, emb) for emb in embs]) |
|
scaled_sims = (sims - T.min(sims))/(T.max(sims)-T.min(sims)) |
|
|
|
ctx = T.tensor(probs_sent_context).log()-ALPHA*T.tensor(probs_para_context).log() |
|
ctx = (ctx-T.min(ctx))/(T.max(ctx)-T.min(ctx)) |
|
|
|
e_emb, e_ctx, e_sim = energy(ctx, dists, scaled_sims, choices, words, ans) |
|
e = e_ctx + BETA * e_emb |
|
|
|
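    # Simulated annealing (maximizing): randomly swap one chosen index and accept the swap
    # if the energy rises, or with probability exp(delta/t) as the temperature t falls
    # linearly toward zero.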
for i in range(its): |
|
t = 1.-(i)/its |
|
mut_idx = random.randrange(k) |
|
orig = choices[mut_idx] |
|
new = orig |
|
while (new in choices): |
|
new = random.randrange(m) |
|
choices[mut_idx] = new |
|
e_emb, e_ctx, e_sim = energy(ctx, dists, scaled_sims, choices, words, ans) |
|
e_new = e_ctx + BETA * e_emb |
|
delta = e_new - e |
|
exponent = delta/t |
|
if exponent < -50: |
|
exponent = -50 |
|
if delta > 0 or exp(exponent) > random.random(): |
|
e = e_new |
|
else: |
|
choices[mut_idx] = orig |
|
if DEBUG_OUTPUT: |
|
print([words[j] for j in choices] + [ans], "e: %f"%(e)) |
|
return choices |
|
|
|
def get_softmax_logits(toks, n_masks = 1, sub_ids = []): |
|
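    '''Return softmax distributions over the vocabulary at n_masks mask positions,
    optionally followed by already-fixed subword ids (sub_ids), substituted in place of
    the single mask token in toks.'''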
|
|
msk_idx = toks.index(toker.mask_token_id) |
|
toks = toks.copy() |
|
toks[msk_idx:msk_idx+1] = [toker.mask_token_id] * n_masks + sub_ids |
|
|
|
|
|
while msk_idx >= 512: |
|
|
|
toks = toks[100:] |
|
msk_idx -= 100 |
|
toks = toks[:512] |
|
|
|
with T.no_grad(): |
|
t=T.tensor([toks]) |
|
m=T.tensor([[1]*len(toks)]) |
|
        if T.cuda.is_available():
            # Tensor.cuda() returns a copy; assign it so the inputs live on the model's device
            t = t.cuda()
            m = m.cuda()
|
output = model(t, m) |
|
sm = T.softmax(output.logits[0, msk_idx:msk_idx+n_masks, :], dim=1) |
|
return sm |
|
|
|
e = 1e-10  # smoothing constant added before taking logs of masked probabilities
|
|
|
def candidates(text, answer): |
|
    '''Create a list of unique distractor candidates that does not include the actual answer'''
|
if DEBUG_OUTPUT: |
|
print(text) |
|
|
|
|
|
doc = nlp(text) |
|
|
|
sents = nltk_sent_toker.tokenize(text) |
|
msk_snt_idx = [i for i in range(len(sents)) if toker.mask_token in sents[i]][0] |
|
just_masked_sentence = sents[msk_snt_idx] |
|
|
|
prv_snts = sents[:msk_snt_idx] |
|
nxt_snts = sents[msk_snt_idx+1:] |
|
|
|
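    # Grow the masked sentence with neighboring sentences until it reaches MIN_SENT_WORDS
    # words, then build a wider context (ctx) roughly three times that length.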
if len(just_masked_sentence.split(' ')) < MIN_SENT_WORDS and len(prv_snts): |
|
just_masked_sentence = ' '.join([prv_snts.pop(), just_masked_sentence]) |
|
|
|
while len(just_masked_sentence.split(' ')) < MIN_SENT_WORDS and (len(prv_snts) or len(nxt_snts)): |
|
if T.rand(1) < 0.5 and len(prv_snts): |
|
just_masked_sentence = ' '.join([prv_snts.pop(), just_masked_sentence]) |
|
elif len(nxt_snts): |
|
just_masked_sentence = ' '.join([just_masked_sentence, nxt_snts.pop(0)]) |
|
|
|
ctx = just_masked_sentence |
|
while len(ctx.split(' ')) < 3 * len(just_masked_sentence.split(' ')) and (len(prv_snts) or len(nxt_snts)): |
|
if len(prv_snts): |
|
ctx = ' '.join([prv_snts.pop(), ctx]) |
|
if len(nxt_snts): |
|
ctx = ' '.join([ctx, nxt_snts.pop(0)]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
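    # Repeat the masked sentence until it is at least as long (in characters) as the full
    # passage, so the sentence-only and paragraph inputs have comparable lengths.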
tiled = just_masked_sentence |
|
while len(tiled) < len(text): |
|
tiled += ' ' + just_masked_sentence |
|
just_masked_sentence = tiled |
|
|
|
if DEBUG_OUTPUT: |
|
print(ctx) |
|
print(text) |
|
print(just_masked_sentence) |
|
toks_para = toker.encode(text) |
|
toks_sent = toker.encode(just_masked_sentence) |
|
|
|
|
|
|
|
|
|
sent_sms_all = [] |
|
para_sms_all = [] |
|
para_sms_right = [] |
|
|
|
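    # For each possible mask count, get the model's softmax over the vocabulary given the
    # sentence-only and full-paragraph contexts; para_sms_right holds their geometric mean,
    # restricted to word-initial tokens at the first position and to suffix tokens after it
    # (used later to decide whether a candidate should be extended with more subwords).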
for i in range(MAX_SUBWORDS): |
|
para_sms = get_softmax_logits(toks_para, i + 1) |
|
para_sms_all.append(para_sms) |
|
sent_sms = get_softmax_logits(toks_sent, i + 1) |
|
sent_sms_all.append(sent_sms) |
|
para_sms_right.append(T.exp((sent_sms[i].log()+para_sms[i].log())/2) * (suffix_mask_inv if i == 0 else suffix_mask)) |
|
|
|
|
|
para_sm_best, para_pos_best = T.max(T.vstack(para_sms_right), 0) |
|
|
|
distractors = [] |
|
stems = [] |
|
embs = [] |
|
sent_probs = [] |
|
para_probs = [] |
|
|
|
ans_stem = stemmer.stem(answer.lower()) |
|
|
|
emb_ans = get_emb(toks_para, toker(answer)['input_ids'][1:-1]) |
|
para_words = text.lower().split(' ') |
|
blank_word_idx = [idx for idx, word in enumerate(text.split(' ')) if toker.mask_token in word][0] |
|
if (blank_word_idx - 1) < 0: |
|
prev_word = 'beforeanytext' |
|
else: |
|
prev_word = para_words[blank_word_idx-1] |
|
if (blank_word_idx + 1) >= len(para_words): |
|
next_word = 'afteralltext' |
|
else: |
|
next_word = para_words[blank_word_idx+1] |
|
|
|
|
|
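    # Rank vocabulary items by how much more likely they are given only the masked sentence
    # than given the whole passage (weighted by ALPHA): good distractors fit the local
    # sentence but not the wider context.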
if len(para_sms_all[0]) > 0: |
|
top_ctx = T.topk((sent_sms_all[0][0]*word_mask+e).log() - ALPHA * (para_sms_all[0][0]*word_mask+e).log(), len(para_sms_all[0][0]), dim=0) |
|
para_top_ids = top_ctx.indices.tolist() |
|
para_top_probs = top_ctx.values.tolist() |
|
|
|
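    # Walk the ranked candidates, skipping any that fail the filters below (case mismatch
    # with the answer, too short, profanity, not a dictionary or passage word, identical to
    # a neighboring word, same stem as the answer, contains digits), until K are collected.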
for i, id in enumerate(para_top_ids): |
|
|
|
sub_ids = [int(id)] |
|
dec = toker.decode(sub_ids).strip() |
|
if DEBUG_OUTPUT: |
|
print('Trying:', dec) |
|
|
|
|
|
|
|
|
|
if dec.isupper() != answer.isupper(): |
|
continue |
|
|
|
if EXTEND_SUBWORDS and para_pos_best[id] > 0: |
|
if DEBUG_OUTPUT: |
|
print("Extending %s with %d masks..."%(dec, para_pos_best[id])) |
|
ext_ids, _ = extend(toks_sent, toks_para, [id], para_pos_best[id], para_words) |
|
sub_ids = ext_ids + sub_ids |
|
dec_ext = toker.decode(sub_ids).strip() |
|
if DEBUG_OUTPUT: |
|
print("Extended %s to %s"%(dec, dec_ext)) |
|
if is_word(dec_ext) or (dec_ext != '' and dec_ext in para_words): |
|
dec = dec_ext |
|
else: |
|
sub_ids = [int(id)] |
|
|
|
if len(toker.decode(sub_ids).lower().strip()) < 2: |
|
continue |
|
|
|
if dec[0].isupper() != answer[0].isupper(): |
|
continue |
|
|
|
|
|
|
|
|
|
|
|
|
|
if dec.lower() in profanity: |
|
continue |
|
|
|
|
|
if not is_word(dec) and dec.lower() not in para_words: |
|
continue |
|
|
|
|
|
if dec.lower() == prev_word or dec.lower() == next_word: |
|
continue |
|
|
|
|
|
stem = stemmer.stem(dec).lower() |
|
if stem in stems or stem == ans_stem: |
|
continue |
|
|
|
|
|
if any(char.isdigit() for char in toker.decode([id])): |
|
continue |
|
|
|
|
|
if DISTRACTORS_FROM_TEXT and dec.lower() not in para_words: |
|
continue |
|
|
|
|
|
|
|
|
|
|
|
distractors.append(dec) |
|
stems.append(stem) |
|
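        # Average the per-subword log-probabilities under both contexts; the resulting
        # probabilities feed the annealing objective via anneal().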
sent_logprob = 0 |
|
para_logprob = 0 |
|
nsubs = len(sub_ids) |
|
for j in range(nsubs): |
|
sub_id = sub_ids[j] |
|
sent_logprob_j = log(sent_sms_all[nsubs-1][j][sub_id]) |
|
para_logprob_j = log(para_sms_all[nsubs-1][j][sub_id]) |
|
|
|
|
|
|
|
|
|
sent_logprob += sent_logprob_j |
|
para_logprob += para_logprob_j |
|
sent_logprob /= nsubs |
|
para_logprob /= nsubs |
|
if DEBUG_OUTPUT: |
|
print("%s (p_sent=%f, p_para=%f)"%(dec,sent_logprob,para_logprob)) |
|
sent_probs.append(exp(sent_logprob)) |
|
para_probs.append(exp(para_logprob)) |
|
|
|
|
|
embs.append(get_emb(toks_para, sub_ids)) |
|
|
|
if len(distractors) >= K: |
|
break |
|
if DEBUG_OUTPUT: |
|
print('Corresponding Text: ', text) |
|
print('Correct Answer: ', answer) |
|
print('Distractors created before annealing: ', distractors) |
|
|
|
|
|
|
|
|
|
return sent_probs, para_probs, embs, emb_ans, distractors |
|
|
|
def create_distractors(text, answer): |
|
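    '''Generate distractor candidates for the masked text and return the three selected by
    simulated annealing.'''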
sent_probs, para_probs, embs, emb_ans, distractors = candidates(text, answer) |
|
|
|
indices = anneal(sent_probs, para_probs, embs, emb_ans, distractors, 3, answer) |
|
return [distractors[x] for x in indices] |
|
|
|
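# Streamlit UI: collect the passage and options, then stream the generated cloze test.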
st.title("nCloze") |
|
st.subheader("Create a multiple-choice cloze test from a passage") |
|
st.markdown("Note: this is a free, CPU-only space and will be slow. For better performance, clone the space with a GPU-enabled environment.") |
|
|
|
def blank(tok):

    '''Strip surrounding punctuation from a token and return (word, token with that word masked).'''

    if tok == 'a(n)':

        strp = tok

    else:

        strp = tok.strip(punctuation)

    if strp == '':

        # pure-punctuation token: return an empty masked string so the caller skips it

        return strp, ''

    masked = tok.replace(strp, toker.mask_token)

    if DEBUG_OUTPUT:

        print(strp, masked)

    return strp, masked
|
|
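# Example (for a RoBERTa-style tokenizer whose mask token is '<mask>'):
#   blank('cycle.') -> ('cycle', '<mask>.')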
|
test = """In contrast to necrosis, which is a form of traumatic cell death that results from acute cellular injury, apoptosis is a highly regulated and controlled process that confers advantages during an organism's life cycle. For example, the separation of fingers and toes in a developing human embryo occurs because cells between the digits undergo apoptosis. Unlike necrosis, apoptosis produces cell fragments called apoptotic bodies that phagocytes are able to engulf and remove before the contents of the cell can spill out onto surrounding cells and cause damage to them.""" |
|
st.header("Basic options") |
|
SPACING = int(st.text_input('Blank spacing', value="7")) |
|
OFFSET = int(st.text_input('First word to blank (0 to use spacing)', value="0")) |
|
st.header("Advanced options") |
|
ALPHA = float(st.text_input('Incorrectness weight', value="0.75")) |
|
BETA = float(st.text_input('Distinctness weight', value="0.75")) |
|
MODEL_TYPE = st.text_input('Masked Language Model (from HuggingFace)', value="roberta-large") |
|
K = 16 |
|
|
|
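# Load the masked language model and tokenizer chosen above; move the model to the GPU
# when one is available.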
model = AutoModelForMaskedLM.from_pretrained(MODEL_TYPE) |
|
|
|
if T.cuda.is_available(): |
|
model.cuda() |
|
|
|
toker = AutoTokenizer.from_pretrained(MODEL_TYPE, add_prefix_space=True) |
|
|
|
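# Masks over the tokenizer vocabulary (sorted by token id):
#   suffix_mask     - subword-continuation tokens (BERT '##...' or RoBERTa tokens without a
#                     leading 'Ġ') made only of letters, digits, or apostrophes
#   suffix_mask_inv - the complementary word-initial tokens
#   word_mask       - word-initial tokens that are dictionary words and not profanity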
sorted_toker_vocab_dict = sorted(toker.vocab.items(), key=lambda x:x[1]) |
|
if toker.mask_token == '[MASK]': |
|
suffix_mask = T.FloatTensor([1 if (('##' == x[0][:2]) and (re.match("^[A-Za-z0-9']*$", x[0]) is not None)) else 0 for x in sorted_toker_vocab_dict]) |
|
else: |
|
suffix_mask = T.FloatTensor([1 if (('Ġ' != x[0][0]) and (re.match("^[A-Za-z0-9']*$", x[0]) is not None)) else 0 for x in sorted_toker_vocab_dict]) |
|
suffix_mask_inv = suffix_mask * -1 + 1 |
|
word_mask = suffix_mask_inv*T.FloatTensor([1 if is_word(x[0][1:]) and x[0][1:].lower() not in profanity else 0 for x in sorted_toker_vocab_dict]) |
|
if T.cuda.is_available(): |
|
suffix_mask=suffix_mask.cuda() |
|
suffix_mask_inv=suffix_mask_inv.cuda() |
|
word_mask = word_mask.cuda() |
|
|
|
st.subheader("Passage") |
|
st.text_area('Passage to create a cloze test from:',value=test,key="text", max_chars=1024, height=275) |
|
|
|
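# Stream the passage word by word, blanking every SPACING-th eligible word (starting at
# OFFSET when given), generating three distractors per blank, and yielding the shuffled
# options inline with the correct answer marked by '*'.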
def generate(): |
|
ws = st.session_state.text.split() |
|
wb = st.session_state.text.split() |
|
|
|
qs = [] |
|
i = OFFSET - 1 if OFFSET > 0 else SPACING |
|
j = 0 |
|
while i < len(ws): |
|
a, b = blank(ws[i]) |
|
while b == '' and i < len(ws)-1: |
|
i += 1 |
|
a, b = blank(ws[i]) |
|
if b != '': |
|
w = ws[i] |
|
ws[i] = b |
|
wb[i] = b |
|
|
|
while j<i: |
|
yield(' ' + ws[j]) |
|
j += 1 |
|
masked = ' '.join(ws) |
|
|
|
ds = create_distractors(masked, a) |
|
print(ds, a) |
|
            q = ds + [a + '\\*']  # escape the asterisk so it renders literally in Markdown
|
random.shuffle(q) |
|
yield(b.replace(toker.mask_token,' **['+', '.join(q)+']**')) |
|
j+=1 |
|
qs.append(ds) |
|
ws[i] = w |
|
i += SPACING |
|
while j<len(ws): |
|
yield(' ' + ws[j]) |
|
j += 1 |
|
|
|
|
|
if st.button("Generate"): |
|
|
|
st.write_stream(generate()) |
|
|