File size: 20,071 Bytes
f440398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ec8ae9
f440398
 
 
 
 
 
cb959e4
1ec8ae9
cb959e4
1ec8ae9
cb959e4
f440398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236be87
f440398
 
 
 
 
 
 
 
 
 
 
 
489d9e7
f440398
 
489d9e7
f440398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489d9e7
f440398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
import streamlit as st

import numpy as np
from numpy import ndarray
import pandas as pd
import torch as T
from torch import Tensor, device
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig, AutoModel
from nltk.corpus import stopwords
from nltk.stem.porter import *
import json
import nltk
from nltk import FreqDist
from nltk.corpus import gutenberg
import urllib.request
from string import punctuation
from math import log,exp,sqrt
import random
from time import sleep

nltk.download('stopwords')
nltk.download('gutenberg')

# Cosine similarity between two 1-D vectors; anneal() uses it for embedding
# distances and for similarity to the answer.
cos = T.nn.CosineSimilarity(dim=0)

# Dictionary word lists and the profanity list, pinned to one nCloze commit.
# NOTE(review): the sleeps presumably space out the GitHub requests; confirm.
urllib.request.urlretrieve("https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/dict-info.txt", 'dict-info.txt')
sleep(1)
urllib.request.urlretrieve("https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/dict-unix.txt", 'dict-unix.txt')
sleep(1)
urllib.request.urlretrieve("https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/profanity.json", 'profanity.json')

CONTEXTUAL_EMBEDDING_LAYERS = [12]  # hidden-state indices get_emb() sums by default
EXTEND_SUBWORDS = True  # let candidates() extend subword candidates; only has an effect when MAX_SUBWORDS > 1
MAX_SUBWORDS = 1  # candidates() scores 1..MAX_SUBWORDS consecutive mask tokens
DEBUG_OUTPUT = True  # print intermediate results to stdout
DISTRACTORS_FROM_TEXT = False  # only accept distractors that also occur in the passage
MIN_SENT_WORDS = 7  # candidates() widens the masked sentence to at least this many words

stemmer = PorterStemmer()
# Word frequencies from the NLTK Gutenberg corpus; in this file they are only
# printed as a sanity check.
freq = FreqDist(i.lower() for i in gutenberg.words())
print(freq.most_common()[:5])

# Read the word lists inside `with` so the file handles are closed promptly.
with open('dict-unix.txt', encoding='utf-8') as f:
    words_unix = set(line.strip() for line in f)
with open('dict-info.txt', encoding='utf-8') as f:
    words_info = set(line.strip() for line in f)
words_small = words_unix.intersection(words_info)
words_large = words_unix.union(words_info)
with open('profanity.json', encoding='utf-8') as f:
    profanity = json.load(f)

import stanza

# Stanza English tokenizer pipeline.
nlp = stanza.Pipeline(lang='en', processors='tokenize')#, model_dir='/data/ondovbd/stanza_resources')

# NLTK Punkt sentence splitter used by candidates().
# NOTE(review): recent NLTK releases ship 'punkt_tab' instead of the pickled
# punkt model; confirm this still loads with the pinned NLTK version.
nltk.download('punkt')
nltk_sent_toker = nltk.data.load('tokenizers/punkt/english.pickle')

def is_word(str):
    '''Return True if the token is a known dictionary word.

    A single apostrophe is allowed when the part before it is a word and the
    part after it is one of the contraction endings t, nt, s or ll. Hyphenated
    tokens are words only if every hyphen-separated part is a word. Lookup is
    case-insensitive against the Unix and info dictionaries.
    '''
    token = str
    lowered = token.lower()
    pieces = lowered.split("'")
    if len(pieces) == 2:
        head, ending = pieces
        return is_word(head) and ending in ('t', 'nt', 's', 'll')
    if len(pieces) > 2:
        return False
    if '-' in token:
        return all(is_word(part) for part in token.split('-'))
    return lowered in words_unix or lowered in words_info

def get_emb(snt_toks, tgt_toks, layers=None):
    '''Embed a group of subword tokens in place of the mask token.

    The first mask token in ``snt_toks`` is replaced by ``tgt_toks`` and the
    whole sequence is run through the model, so the target tokens are embedded
    in context. The selected hidden layers are summed, then averaged over the
    target token positions.

    snt_toks: token ids of the text, including the mask token id
    tgt_toks: token ids (subwords) that replace the mask token
    layers (optional): indices into the model's hidden states to sum;
        defaults to CONTEXTUAL_EMBEDDING_LAYERS

    Returns the mean, over the target positions, of the summed hidden states
    (presumably a (hidden_size,) tensor -- confirm for single-token inputs).
    '''
    mask_idx = snt_toks.index(toker.mask_token_id)
    snt_toks = snt_toks.copy()

    # Keep every target token inside the 512-token window by dropping text
    # from the front, 100 tokens at a time.
    while mask_idx + len(tgt_toks) - 1 >= 512:
        snt_toks = snt_toks[100:]
        mask_idx -= 100

    snt_toks[mask_idx:mask_idx + 1] = tgt_toks
    snt_toks = snt_toks[:512]

    input_ids = T.tensor([snt_toks])
    attention_mask = T.ones_like(input_ids)
    if T.cuda.is_available():
        # Tensor.cuda() returns a new tensor, so the result must be rebound to
        # match the model, which is moved to the GPU at load time.
        input_ids = input_ids.cuda()
        attention_mask = attention_mask.cuda()
    with T.no_grad():
        output = model(input_ids, attention_mask, output_hidden_states=True)
    layers = CONTEXTUAL_EMBEDDING_LAYERS if layers is None else layers
    summed = T.stack([output.hidden_states[i] for i in layers]).sum(0).squeeze()
    # Only select the positions that hold the requested word's subwords.
    return summed[mask_idx:mask_idx + len(tgt_toks)].mean(dim=0)

def energy(ctx, scaled_dists, scaled_sims, choices, words, ans):
    '''Cost function to help choose best distractors.

    Returns ``(e_emb, e_ctx, e_sim)``:
      e_emb -- harmonic mean (as a float) of the scaled pairwise distances
               among the chosen candidates plus the answer, which is
               addressed by index ``len(ctx)``
      e_ctx -- sum of ``ctx`` over the chosen indices
      e_sim -- harmonic mean (as a float) of ``scaled_sims`` over the chosen
               indices

    ``scaled_dists`` is keyed by ``"<larger index>-<smaller index>"``.
    ``words`` and ``ans`` are accepted but not used.
    '''
    e_ctx = sum(ctx[idx] for idx in choices)
    inv_sim_total = sum(1. / scaled_sims[idx] for idx in choices)
    e_sim = float(len(choices)) / inv_sim_total

    # Every pair among the choices and the answer, visited in the same order
    # as a lower-triangular sweep.
    members = choices + [len(ctx)]
    inv_dist_total = 0
    n_pairs = 0
    for pos, hi_cand in enumerate(members):
        for lo_cand in members[:pos]:
            key = '%s-%s' % (max(hi_cand, lo_cand), min(hi_cand, lo_cand))
            inv_dist_total += 1. / scaled_dists[key]
            n_pairs += 1
    e_emb = float(n_pairs) / inv_dist_total
    return float(e_emb), e_ctx, float(e_sim)

def anneal(probs_sent_context, probs_para_context, embs, emb_ans, words, k, ans):
    '''Pick k distractor indices by simulated annealing.

    The search maximises ``e_ctx + BETA * e_emb`` from energy(). ``e_ctx``
    rewards a high context score: the log-probability under the sentence
    context minus ALPHA times the log-probability under the passage context,
    min-max scaled. ``e_emb`` is a harmonic-mean embedding distance among the
    chosen candidates and the answer.

    probs_sent_context, probs_para_context: per-candidate probabilities
    embs: per-candidate embedding tensors; emb_ans: the answer's embedding
    words: candidate strings (only used in debug output)
    k: number of indices to return
    ans: the answer string (only used in debug output)

    Returns a list of k distinct indices into the candidate lists. When there
    are k or fewer candidates, all indices are returned in order because there
    is nothing to choose; when k <= 0 the result is empty.
    '''
    m = len(probs_sent_context)
    if k <= 0:
        return []
    if m <= k:
        # The mutation step below draws an index that is not already chosen.
        # With m == k no such index exists and the loop would never end; with
        # m < k the initial choices would index past the candidate lists.
        return list(range(m))
    its = 1000
    choices = list(range(k))

    # Pairwise cosine distances among the candidates and the answer (appended
    # last), keyed "<larger>-<smaller>" and min-max scaled.
    dists = {}
    embsa = embs + [emb_ans]
    for i in range(len(embsa)):
        for j in range(i):
            dists['%s-%s' % (i, j)] = 1 - cos(embsa[i], embsa[j])  # cosine "distance"

    dist_min = T.min(T.tensor(list(dists.values())))
    dist_max = T.max(T.tensor(list(dists.values())))
    for key, dist in dists.items():
        dists[key] = (dist - dist_min) / (dist_max - dist_min)

    sims = T.tensor([cos(emb_ans, emb) for emb in embs])
    scaled_sims = (sims - T.min(sims)) / (T.max(sims) - T.min(sims))

    ctx = T.tensor(probs_sent_context).log() - ALPHA * T.tensor(probs_para_context).log()
    ctx = (ctx - T.min(ctx)) / (T.max(ctx) - T.min(ctx))

    e_emb, e_ctx, e_sim = energy(ctx, dists, scaled_sims, choices, words, ans)
    e = e_ctx + BETA * e_emb
    for i in range(its):
        t = 1. - (i) / its  # temperature falls linearly from 1 toward 0
        mut_idx = random.randrange(k)  # which choice to mutate
        orig = choices[mut_idx]
        new = orig
        while (new in choices):  # mutate choice until not in current list
            new = random.randrange(m)
        choices[mut_idx] = new
        e_emb, e_ctx, e_sim = energy(ctx, dists, scaled_sims, choices, words, ans)
        e_new = e_ctx + BETA * e_emb
        delta = e_new - e
        exponent = delta / t
        if exponent < -50:
            exponent = -50  # avoid underflow
        # Always accept improvements; accept worse states with Boltzmann probability.
        if delta > 0 or exp(exponent) > random.random():
            e = e_new  # accept new state
        else:
            choices[mut_idx] = orig
    if DEBUG_OUTPUT:
        print([words[j] for j in choices] + [ans], "e: %f" % (e))
    return choices

def get_softmax_logits(toks, n_masks = 1, sub_ids = None):
    '''Return the model's softmax distributions at a run of mask positions.

    The first mask token in ``toks`` is replaced by ``n_masks`` mask tokens
    followed by ``sub_ids``. The text is shifted so that every mask lies
    inside the 512-token window, and the model is run once.

    toks: token ids containing the mask token id
    n_masks: number of consecutive mask tokens to predict
    sub_ids (optional): token ids inserted right after the masks

    Returns a tensor of shape (n_masks, number of logits per position); each
    row is a softmax over the vocabulary.
    '''
    sub_ids = [] if sub_ids is None else sub_ids
    msk_idx = toks.index(toker.mask_token_id)
    toks = toks.copy()
    toks[msk_idx:msk_idx + 1] = [toker.mask_token_id] * n_masks + list(sub_ids)

    # Drop leading text, 100 tokens at a time, until the LAST mask position
    # fits in the 512-token window (matching the check in get_emb).
    while msk_idx + n_masks - 1 >= 512:
        toks = toks[100:]
        msk_idx -= 100
    toks = toks[:512]

    input_ids = T.tensor([toks])
    attention_mask = T.ones_like(input_ids)
    if T.cuda.is_available():
        # Tensor.cuda() returns a new tensor; rebind so the inputs sit on the
        # same device as the model.
        input_ids = input_ids.cuda()
        attention_mask = attention_mask.cuda()
    with T.no_grad():
        output = model(input_ids, attention_mask)
    return T.softmax(output.logits[0, msk_idx:msk_idx + n_masks, :], dim=1)

# Small epsilon added to masked probabilities in candidates() before log(),
# so tokens zeroed by word_mask do not produce log(0).
e = 1.0e-10

def candidates(text, answer):
    '''Create a list of unique distractor candidates that excludes the answer.

    text: the passage, containing one ``toker.mask_token``
    answer: the correct word that the mask replaced

    Vocabulary tokens are ranked by their log-probability in the masked
    sentence (tiled to the passage's length) minus ALPHA times their
    log-probability in the full passage. They are then filtered on:
    capitalization, dictionary or passage membership, profanity, repeating an
    adjacent word, sharing a stem with the answer or an earlier candidate,
    and digits. Filtering stops once K candidates have been collected.

    Returns ``(sent_probs, para_probs, embs, emb_ans, distractors)``.
    These are the per-candidate probabilities under the sentence and passage
    contexts, the contextual embedding of each candidate, the embedding of
    the answer, and the candidate strings. The per-candidate lists are
    aligned by index.
    '''
    if DEBUG_OUTPUT:
        print(text)

    # Locate the sentence that contains the mask.
    sents = nltk_sent_toker.tokenize(text)
    msk_snt_idx = [i for i in range(len(sents)) if toker.mask_token in sents[i]][0]
    just_masked_sentence = sents[msk_snt_idx]

    prv_snts = sents[:msk_snt_idx]
    nxt_snts = sents[msk_snt_idx+1:]

    # Grow a short sentence. First prepend the previous sentence, then keep
    # adding neighbours (randomly preferring the previous one) until it has
    # MIN_SENT_WORDS space-separated words or no neighbours remain.
    if len(just_masked_sentence.split(' ')) < MIN_SENT_WORDS and len(prv_snts):
        just_masked_sentence = ' '.join([prv_snts.pop(), just_masked_sentence])

    while len(just_masked_sentence.split(' ')) < MIN_SENT_WORDS and (len(prv_snts) or len(nxt_snts)):
        if T.rand(1) < 0.5 and len(prv_snts):
            just_masked_sentence = ' '.join([prv_snts.pop(), just_masked_sentence])
        elif len(nxt_snts):
            just_masked_sentence = ' '.join([just_masked_sentence, nxt_snts.pop(0)])

    # A wider window (about three times the sentence); only printed in debug output.
    ctx = just_masked_sentence
    while len(ctx.split(' ')) < 3 * len(just_masked_sentence.split(' ')) and (len(prv_snts) or len(nxt_snts)):
        if len(prv_snts):
            ctx = ' '.join([prv_snts.pop(), ctx])
        if len(nxt_snts):
            ctx = ' '.join([ctx, nxt_snts.pop(0)])

    # Repeat the sentence until it has at least as many characters as the passage.
    tiled = just_masked_sentence
    while len(tiled) < len(text):
        tiled += ' ' + just_masked_sentence
    just_masked_sentence = tiled

    if DEBUG_OUTPUT:
        print(ctx)
        print(text)
        print(just_masked_sentence)
    toks_para = toker.encode(text)
    toks_sent = toker.encode(just_masked_sentence)

    # Softmax distributions for 1..MAX_SUBWORDS consecutive masks, from the
    # passage and from the tiled sentence.
    sent_sms_all = []
    para_sms_all = []
    para_sms_right = []

    for i in range(MAX_SUBWORDS):
        para_sms = get_softmax_logits(toks_para, i + 1)
        para_sms_all.append(para_sms)
        sent_sms = get_softmax_logits(toks_sent, i + 1)
        sent_sms_all.append(sent_sms)
        # Geometric mean of both contexts at the last mask position. One mask
        # keeps only word-initial tokens; more masks keep only suffix tokens.
        para_sms_right.append(T.exp((sent_sms[i].log()+para_sms[i].log())/2) * (suffix_mask_inv if i == 0 else suffix_mask))

    # For each vocabulary id, the mask count (minus one) that scored it highest.
    _, para_pos_best = T.max(T.vstack(para_sms_right), 0)

    distractors = []
    stems = []
    embs = []
    sent_probs = []
    para_probs = []

    ans_stem = stemmer.stem(answer.lower())

    emb_ans = get_emb(toks_para, toker(answer)['input_ids'][1:-1])
    para_words = text.lower().split(' ')
    # Index of the blank among space-separated words (punctuation stays attached).
    blank_word_idx = [idx for idx, word in enumerate(text.split(' ')) if toker.mask_token in word][0]
    if (blank_word_idx - 1) < 0:
        prev_word = 'beforeanytext'
    else:
        prev_word = para_words[blank_word_idx-1]
    if (blank_word_idx + 1) >= len(para_words):
        next_word = 'afteralltext'
    else:
        next_word = para_words[blank_word_idx+1]

    # Need to check if the token is outside of the tokenizer based on predictions being made at all
    if len(para_sms_all[0]) > 0:
        top_ctx = T.topk((sent_sms_all[0][0]*word_mask+e).log() - ALPHA * (para_sms_all[0][0]*word_mask+e).log(), len(para_sms_all[0][0]), dim=0)
        para_top_ids = top_ctx.indices.tolist()

        for tok_id in para_top_ids:

            sub_ids = [int(tok_id)]  # cumulative list of subword token ids
            dec = toker.decode(sub_ids).strip()
            if DEBUG_OUTPUT:
                print('Trying:', dec)

            if dec.isupper() != answer.isupper():
                continue

            if EXTEND_SUBWORDS and para_pos_best[tok_id] > 0:
                # NOTE(review): `extend` is not defined in this file; this
                # branch is only reachable when MAX_SUBWORDS > 1.
                if DEBUG_OUTPUT:
                    print("Extending %s with %d masks..."%(dec, para_pos_best[tok_id]))
                ext_ids, _ = extend(toks_sent, toks_para, [tok_id], para_pos_best[tok_id], para_words)
                sub_ids = ext_ids + sub_ids
                dec_ext = toker.decode(sub_ids).strip()
                if DEBUG_OUTPUT:
                    print("Extended %s to %s"%(dec, dec_ext))
                if is_word(dec_ext) or (dec_ext != '' and dec_ext in para_words):
                    dec = dec_ext  # choose new word
                else:
                    sub_ids = [int(tok_id)]  # reset

            if len(toker.decode(sub_ids).lower().strip()) < 2:
                continue

            if dec[0].isupper() != answer[0].isupper():
                continue

            if dec.lower() in profanity:
                continue

            # make sure is a word, either in dict or somewhere else in text
            if not is_word(dec) and dec.lower() not in para_words:
                continue

            # make sure is not the same as an adjacent word
            if dec.lower() == prev_word or dec.lower() == next_word:
                continue

            # Don't add the distractor if stem matches another
            stem = stemmer.stem(dec).lower()
            if stem in stems or stem == ans_stem:
                continue

            # Only add distractor if it does not contain a number
            if any(char.isdigit() for char in toker.decode([tok_id])):
                continue

            # Only add distractor if the distractor exists in the text already
            if DISTRACTORS_FROM_TEXT and dec.lower() not in para_words:
                continue

            # PASSED ALL TESTS; finally add distractor and computations
            distractors.append(dec)
            stems.append(stem)
            # Mean log-probability over the candidate's subword positions.
            sent_logprob = 0
            para_logprob = 0
            nsubs = len(sub_ids)
            for j in range(nsubs):
                sub_id = sub_ids[j]
                sent_logprob_j = log(sent_sms_all[nsubs-1][j][sub_id])
                para_logprob_j = log(para_sms_all[nsubs-1][j][sub_id])
                sent_logprob += sent_logprob_j
                para_logprob += para_logprob_j
            sent_logprob /= nsubs
            para_logprob /= nsubs
            if DEBUG_OUTPUT:
                print("%s (p_sent=%f, p_para=%f)"%(dec,sent_logprob,para_logprob))
            sent_probs.append(exp(sent_logprob))
            para_probs.append(exp(para_logprob))
            embs.append(get_emb(toks_para, sub_ids))

            if len(distractors) >= K:
                break
    if DEBUG_OUTPUT:
        print('Corresponding Text: ', text)
        print('Correct Answer: ', answer)
        print('Distractors created before annealing: ', distractors)

    return sent_probs, para_probs, embs, emb_ans, distractors

def create_distractors(text, answer, k=3):
    '''Return up to ``k`` distractors for the masked word in ``text``.

    text: passage containing one mask token
    answer: the correct word for the blank
    k: number of distractors wanted (default 3)

    Candidates come from candidates(). When there are more than ``k`` of them,
    anneal() picks the subset. Otherwise all candidates are returned: the
    annealing mutation step must draw an index not already chosen, which is
    impossible unless there are more candidates than choices.
    '''
    if k <= 0:
        return []
    sent_probs, para_probs, embs, emb_ans, distractors = candidates(text, answer)
    if len(distractors) <= k:
        return list(distractors)
    indices = anneal(sent_probs, para_probs, embs, emb_ans, distractors, k, answer)
    return [distractors[x] for x in indices]

# ---- Page header ----------------------------------------------------------
st.title("nCloze")
st.subheader("Create a multiple-choice cloze test from a passage")
st.markdown(
    "Note: this is a free, CPU-only space and will be slow. "
    "For better performance, clone the space with a GPU-enabled environment."
)

def blank(tok):
    '''Split a whitespace token into the word to blank and its masked form.

    Leading and trailing punctuation stays around the mask. The literal
    token ``a(n)`` is blanked whole.

    Returns ``(word, masked_token)``. For a token that is only punctuation it
    returns ``('', '')``, which generate() treats as "skip to the next token".
    '''
    if tok == 'a(n)':
        strp = tok
    else:
        strp = tok.strip(punctuation)
    if not strp:
        # str.replace('', x) would insert the mask between every character,
        # and an empty answer later breaks candidates() (answer[0]).
        return '', ''
    masked = tok.replace(strp, toker.mask_token)
    if DEBUG_OUTPUT:
        print(strp, masked)
    return strp, masked

# Sample passage pre-filled into the passage text area.
test = """In contrast to necrosis, which is a form of traumatic cell death that results from acute cellular injury, apoptosis is a highly regulated and controlled process that confers advantages during an organism's life cycle. For example, the separation of fingers and toes in a developing human embryo occurs because cells between the digits undergo apoptosis. Unlike necrosis, apoptosis produces cell fragments called apoptotic bodies that phagocytes are able to engulf and remove before the contents of the cell can spill out onto surrounding cells and cause damage to them."""
st.header("Basic options")
# Clamp to at least 1: generate() advances by SPACING, so 0 or a negative
# value would never move past the current word.
SPACING = max(1, int(st.text_input('Blank spacing', value="7")))
OFFSET = int(st.text_input('First word to blank (0 to use spacing)', value="4"))
st.header("Advanced options")
# Weight of the passage-context log-probability subtracted when scoring candidates.
ALPHA = float(st.text_input('Incorrectness weight', value="0.75"))
# Weight of the embedding-distance term in the annealing objective.
BETA = float(st.text_input('Distinctness weight', value="0.25"))
MODEL_TYPE = st.text_input('Masked Language Model (from HuggingFace)', value="roberta-large")
K = 16  # candidates() stops after collecting this many candidates

# `re` is used below but is otherwise only reachable through the
# `from nltk.stem.porter import *` star import; import it explicitly.
import re

model = AutoModelForMaskedLM.from_pretrained(MODEL_TYPE)#, cache_dir=CACHE_DIR)

if T.cuda.is_available():
    model.cuda()

toker = AutoTokenizer.from_pretrained(MODEL_TYPE, add_prefix_space=True)

# (token, id) pairs in id order, so position i of each mask below lines up
# with logit column i.
sorted_toker_vocab_dict = sorted(toker.vocab.items(), key=lambda x: x[1])
if toker.mask_token == '[MASK]':  # BERT style: continuation pieces start with '##'
    # Test the text after the '##' marker. '#' is outside the character class,
    # so testing the whole piece would mark no token as a suffix.
    suffix_mask = T.FloatTensor([1 if (('##' == x[0][:2]) and (re.match("^[A-Za-z0-9']*$", x[0][2:]) is not None)) else 0 for x in sorted_toker_vocab_dict])  # 1 means is-suffix and 0 means not-suffix
    # Word-initial BERT pieces carry no prefix marker, so the piece is the word.
    vocab_words = [x[0] for x in sorted_toker_vocab_dict]
else:  # RoBERTa style: word-initial pieces start with 'Ġ'
    suffix_mask = T.FloatTensor([1 if (('Ġ' != x[0][0]) and (re.match("^[A-Za-z0-9']*$", x[0]) is not None)) else 0 for x in sorted_toker_vocab_dict])  # 1 means is-suffix and 0 means not-suffix
    # Drop the leading 'Ġ' marker to recover the word.
    vocab_words = [x[0][1:] for x in sorted_toker_vocab_dict]
suffix_mask_inv = suffix_mask * -1 + 1
# 1 for non-suffix tokens whose text is a dictionary word and not profane.
word_mask = suffix_mask_inv * T.FloatTensor([1 if is_word(w) and w.lower() not in profanity else 0 for w in vocab_words])
if T.cuda.is_available():
    suffix_mask = suffix_mask.cuda()
    suffix_mask_inv = suffix_mask_inv.cuda()
    word_mask = word_mask.cuda()

# Editable passage; generate() reads it back through st.session_state.text.
st.subheader("Passage")
st.text_area(
    'Passage to create a cloze test from:',
    key="text",
    value=test,
    height=275,
    max_chars=1024,
)

def generate():
    '''Yield the cloze test as markdown fragments for st.write_stream.

    Blanking starts at word OFFSET (1-based; 0 means start at word SPACING)
    and then repeats every SPACING-th whitespace token of the passage. Each
    blank is rendered as a bracketed, shuffled list of distractors plus the
    correct answer; the answer is marked with an escaped asterisk. A token
    that blank() cannot blank is skipped in favour of the next one.
    '''
    ws = st.session_state.text.split()
    # A step below 1 would never advance past the current word.
    step = max(1, SPACING)

    i = OFFSET - 1 if OFFSET > 0 else step - 1
    j = 0  # index of the next token not yet yielded
    while i < len(ws):
        a, b = blank(ws[i])
        while b == '' and i < len(ws) - 1:
            i += 1
            a, b = blank(ws[i])
        if b != '':
            w = ws[i]
            ws[i] = b

            # Emit the untouched words before the blank.
            while j < i:
                yield ' ' + ws[j]
                j += 1
            # The model sees the full passage with only this word masked.
            masked = ' '.join(ws)
            ds = create_distractors(masked, a)
            if DEBUG_OUTPUT:
                print(ds, a)
            q = ds + [a + '\\*']
            random.shuffle(q)
            yield b.replace(toker.mask_token, ' **[' + ', '.join(q) + ']**')
            j += 1
            ws[i] = w
        i += step
    while j < len(ws):
        yield ' ' + ws[j]
        j += 1

# Build and stream the cloze test when the user presses the button.
generate_clicked = st.button("Generate")
if generate_clicked:
    st.write_stream(generate())