Spaces:
Sleeping
Sleeping
File size: 6,541 Bytes
f78244e 92627cb f78244e 92627cb f78244e 92627cb f78244e 92627cb f78244e 92627cb f78244e 92627cb f78244e 0880270 87596b4 f78244e 6a843e3 f78244e 88ec4e4 f78244e c1f4179 88ec4e4 f78244e a1028e2 f78244e 74e144a ab26b95 a4645a6 f78244e 92627cb f78244e 92627cb f78244e 92627cb f78244e 92627cb f78244e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import random
import re
from poems import SAMPLE_POEMS
import langid
import numpy as np
import streamlit as st
import torch
from icu_tokenizer import Tokenizer
from transformers import pipeline
# Model identifiers handed to transformers.pipeline("fill-mask", ...):
# the poetry-tuned ALBERTI checkpoint and the multilingual BERT baseline.
MODELS = dict(
    ALBERTI="flax-community/alberti-bert-base-multilingual-cased",
    mBERT="bert-base-multilingual-cased",
)
# How many top-probability predictions are kept for each masked position.
TOPK = 50
# Wide page layout, presumably to give the three side-by-side columns
# (input poem, ALBERTI output, mBERT output) more horizontal room.
st.set_page_config(layout="wide")
def mask_line(line, language="es", restrictive=True):
    """Replace one randomly chosen token of *line* with ``[MASK]``.

    The line is tokenized with the ICU ``Tokenizer`` for *language* and the
    tokens are re-joined with single spaces.

    Args:
        line: A single verse of the poem.
        language: Language code used to build the tokenizer and to choose the
            masking rule (``"zh"`` is treated specially).
        restrictive: Only honoured for ``"zh"``. For every other language it
            is recomputed: masking is limited to tokens longer than three
            characters whenever the line has at least one such token.

    Returns:
        The space-joined tokens with exactly one token replaced by ``[MASK]``.

    Raises:
        ValueError: If *line* produces no tokens.
    """
    tokenizer = Tokenizer(lang=language)
    token_list = tokenizer.tokenize(line)
    if not token_list:
        raise ValueError(f"cannot mask a line without tokens: {line!r}")

    if language != "zh":
        # Restrict to long tokens only when there is at least one to pick.
        restrictive = not all(len(token) <= 3 for token in token_list)

    eligible = list(range(len(token_list)))
    if restrictive:
        preferred = [
            i for i in eligible
            if len(token_list[i]) > 3
            or (language == "zh" and token_list[i].isalpha())
        ]
        # A "zh" line may have no long or alphabetic token (e.g. only
        # punctuation); fall back to any token instead of retrying forever.
        eligible = preferred or eligible

    # Uniform choice among eligible tokens; same distribution as repeatedly
    # drawing a random token until one qualifies.
    token_list[random.choice(eligible)] = "[MASK]"
    return " ".join(token_list)
def filter_candidates(candidates, get_any_candidate=False):
    """Return the ``"sequence"`` of the best usable fill-mask candidate.

    Args:
        candidates: Candidate dicts ordered best-first (``infer_candidates``
            builds them from ``topk``); each needs ``"token_str"`` and
            ``"sequence"`` keys.
        get_any_candidate: If True, skip filtering and take the first one.

    Returns:
        The sequence of the first candidate whose predicted token is a whole,
        purely alphabetic word (not a ``##`` word-piece continuation). If no
        candidate qualifies, the first candidate's sequence is returned.

    Raises:
        ValueError: If *candidates* is empty.
    """
    if not candidates:
        # Nothing to fall back to; fail loudly instead of looping.
        raise ValueError("filter_candidates() needs at least one candidate")
    if not get_any_candidate:
        for candidate in candidates:
            token = candidate["token_str"]
            # "##" marks a word-piece continuation; isalpha() rejects
            # punctuation, digits and strings containing spaces.
            if not token.startswith("##") and token.isalpha():
                return candidate["sequence"]
    return candidates[0]["sequence"]
def infer_candidates(nlp, line):
    """Rank the ``TOPK`` most probable fillers for the mask token in *line*.

    Args:
        nlp: A ``transformers`` fill-mask pipeline; its ``tokenizer`` and
            ``model`` attributes are used directly.
        line: Text containing exactly one ``nlp.tokenizer.mask_token``.

    Returns:
        A list of ``TOPK`` dicts in descending probability order with keys
        ``"sequence"`` (the detokenized line as HTML, the predicted token
        wrapped in a red ``<b>``), ``"score"`` (softmax probability),
        ``"token"`` (predicted vocabulary id), ``"token_str"`` (its decoded
        text) and ``"masked_index"`` (mask position in the input ids).

    Raises:
        ValueError: If *line* does not contain exactly one mask token.
    """
    # Normalize typographic punctuation (em dash, en dash, right single
    # quotation mark, horizontal ellipsis) to ASCII before tokenizing.
    for fancy, plain in (("\u2014", "-"), ("\u2013", "-"),
                         ("\u2019", "'"), ("\u2026", "...")):
        line = line.replace(fancy, plain)

    tokenizer = nlp.tokenizer
    inputs = tokenizer(line, return_tensors="pt")
    input_ids = inputs["input_ids"][0]
    mask_positions = torch.nonzero(input_ids == tokenizer.mask_token_id,
                                   as_tuple=False)
    if mask_positions.numel() != 1:
        raise ValueError(
            f"expected exactly one mask token, found "
            f"{mask_positions.numel()}: {line!r}")
    masked_pos = mask_positions.item()

    # Public model API instead of the pipeline's private helpers; inference
    # only, so no autograd graph is needed.
    model_inputs = {name: tensor.to(nlp.model.device)
                    for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = nlp.model(**model_inputs).logits[0, masked_pos, :]
    values, predictions = logits.softmax(dim=0).topk(TOPK)

    # Decode the unchanged context once (plain Python ints, so the input
    # tensor is never mutated); only the masked slot varies per candidate.
    token_ids = input_ids.tolist()
    pad_id = tokenizer.pad_token_id
    context = [tokenizer.decode([tid], skip_special_tokens=True)
               for tid in token_ids]

    result = []
    for score, token_id in zip(values.tolist(), predictions.tolist()):
        filled = tokenizer.decode([token_id], skip_special_tokens=True)
        pieces = []
        for pos, (tid, text) in enumerate(zip(token_ids, context)):
            if pos == masked_pos:
                text = filled
            elif tid == pad_id:
                continue  # drop padding
            if text.startswith("##") and pieces:
                # Glue word-piece continuations onto the preceding piece.
                pieces[-1] += text[2:]
            elif pos == masked_pos:
                pieces += ['<b style="color: #ff0000;">', text, "</b>"]
            else:
                pieces.append(text)
        result.append(
            {
                "sequence": " ".join(pieces).strip(),
                "score": score,
                "token": token_id,
                "token_str": tokenizer.decode(token_id),
                "masked_index": masked_pos,
            }
        )
    return result
def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
    """Fill one masked word per verse of *poem* with a fill-mask model.

    Args:
        poem: Iterable of verse strings. Blank entries (empty or
            whitespace-only) are emitted as empty lines and never masked.
        ml_model: Model identifier passed to ``transformers.pipeline``.
        masking: If True, each verse is masked with ``mask_line``; if False,
            each verse is expected to already contain its mask token.
        language: Language code forwarded to ``mask_line``.

    Returns:
        ``(unmasked_poem, masked_lines)``: the rewritten verses joined with
        ``"<br>"`` and the list of masked verses that were fed to the model.
    """
    # NOTE(review): this builds (and loads) the pipeline on every call.
    nlp = pipeline("fill-mask", model=ml_model)
    unmasked_lines = []
    masked_lines = []
    for line in poem:
        if not line.strip():
            # Blank or whitespace-only verses have no word to mask.
            unmasked_lines.append("")
            masked_lines.append("")
            continue
        masked_line = mask_line(line, language) if masking else line
        masked_lines.append(masked_line)
        candidates = infer_candidates(nlp, masked_line)
        unmasked_lines.append(filter_candidates(candidates))
    return "<br>".join(unmasked_lines), masked_lines
# --- Sidebar: project description, poem picker and usage notes ---
st.sidebar.markdown(
    """# ALBERTI vs BERT 🥊
We present ALBERTI, our BERT-based multilingual model for poetry.""")
st.sidebar.markdown(
    """We have trained bert on a huge (for poetry, that is) corpus of
multilingual poetry to try to get a more 'poetic' model. This is the result
of our work.
You can find more information on the [project's site](https://huggingface.co/flax-community/alberti-bert-base-multilingual-cased)""")
sample_chooser = st.sidebar.selectbox(
    "Choose a poem",
    list(SAMPLE_POEMS.keys()),
)
st.sidebar.markdown("""# How to use
You can choose from a list of example poems in Spanish, English, French, German,
Chinese and Arabic, but you can also paste a poem, or write it yourself!
Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen
poem, randomly masking one word per verse, and get the two new versions for each of the models.
The list of languages used on the training of ALBERTI are:
* Arabic
* Chinese
* Czech
* English
* Finnish
* French
* German
* Hungarian
* Italian
* Portuguese
* Russian
* Spanish""")

# --- Main area: input poem on the left, one output column per model ---
col1, col2, col3 = st.columns(3)
st.markdown(
    """
<style>
label {
font-size: 1rem !important;
font-weight: bold !important;
}
.block-container {
padding-left: 1rem !important;
padding-right: 1rem !important;
}
</style>
""", unsafe_allow_html=True)
if sample_chooser:
    user_input = col1.text_area("Input poem",
                                "\n".join(SAMPLE_POEMS[sample_chooser]),
                                height=600)
    poem = user_input.split("\n")
    rewrite_button = col1.button("Rewrite!")
    if "[MASK]" in user_input or "<mask>" in user_input:
        col1.error("You don't have to mask the poem, we'll do it for you!")
    if rewrite_button:
        # Detected language code (first element of langid's result).
        lang = langid.classify(user_input)[0]
        unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
        col2.write(f"""<b>Output poem from ALBERTI</b>
{unmasked_poem}""", unsafe_allow_html=True)
        # Reuse ALBERTI's masked verses so both models fill the same gaps.
        unmasked_poem_2, _ = rewrite_poem(masked_poem, ml_model=MODELS["mBERT"],
                                          masking=False)
        col3.write(f"""<b>Output poem from mBERT</b>
{unmasked_poem_2}""", unsafe_allow_html=True)
|