# Spaces: Build error (page-header residue from extraction; not part of the program)
import copy | |
import functools | |
import itertools | |
import logging | |
import random | |
import string | |
from typing import List, Optional | |
import requests | |
import numpy as np | |
import tensorflow as tf | |
import streamlit as st | |
from gazpacho import Soup, get | |
from transformers import BertTokenizer, TFAutoModelForMaskedLM | |
from rhyme_with_ai.utils import color_new_words, pairwise, find_last_word, sanitize | |
# First line used when the user submits nothing.
DEFAULT_QUERY = "Machines will take over the world soon"
# Maximum number of rhyme words requested and used per run.
N_RHYMES = 10
# Number of mutation rounds per whitespace-separated word of the query.
ITER_FACTOR = 5

# Sidebar language selector; the third positional argument (0) is the index
# of the option selected by default.
LANGUAGE = st.sidebar.radio("Language", ["english", "dutch"], 0)

# Masked-language-model identifier for each supported language.
_MODEL_PATH_BY_LANGUAGE = {
    "english": "bert-large-cased-whole-word-masking",
    "dutch": "GroNLP/bert-base-dutch-cased",
}
if LANGUAGE not in _MODEL_PATH_BY_LANGUAGE:
    raise NotImplementedError(
        f"Unsupported language ({LANGUAGE}) expected 'english' or 'dutch'."
    )
MODEL_PATH = _MODEL_PATH_BY_LANGUAGE[LANGUAGE]
def main():
    """Render the page: credits, title, query input and rhyme suggestions.

    Reads the first line from the user, looks up rhyme words for it in the
    selected language and, if any were found, starts the rhyme generation.
    Otherwise a "No rhyme words found" message is written to the page.
    """
    credits_html = (
        "<sup>Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>"
    )
    st.markdown(credits_html, unsafe_allow_html=True)
    st.title("Rhyme with AI")

    # Fall back to the default line when the input is empty.
    query = get_query() or DEFAULT_QUERY

    rhyme_words_options = query_rhyme_words(
        query, n_rhymes=N_RHYMES, language=LANGUAGE
    )
    if not rhyme_words_options:
        st.write("No rhyme words found")
        return

    logging.getLogger(__name__).info("Got rhyme words: %s", rhyme_words_options)
    start_rhyming(query, rhyme_words_options)
def get_query():
    """Ask the user for the first line and return it sanitized.

    Returns
    -------
    The sanitized text-input value, or ``DEFAULT_QUERY`` when that value is
    empty.
    """
    raw_input = st.text_input(
        "Write your first line and press ENTER to rhyme:", DEFAULT_QUERY
    )
    cleaned = sanitize(raw_input)
    return cleaned if cleaned else DEFAULT_QUERY
def start_rhyming(query, rhyme_words_options):
    """Generate rhyming second lines for ``query`` and stream them to the page.

    Parameters
    ----------
    query : First line to rhyme with
    rhyme_words_options : Candidate rhyme words; at most ``N_RHYMES`` are used

    Notes
    -----
    The number of mutation rounds is ``ITER_FACTOR`` times the number of
    whitespace-separated words in ``query``. After every round the current
    suggestions are re-rendered and the progress bar is advanced.
    """
    st.markdown("## My Suggestions:")
    progress_bar = st.progress(0)
    status_text = st.empty()

    max_iter = len(query.split()) * ITER_FACTOR
    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)

    # Placeholder "previous" sentences for the first round, one per rhyme word.
    current_sentences = [" "] * len(rhyme_words)
    for i in range(max_iter):
        # A shallow copy suffices: the entries are immutable strings.
        previous_sentences = list(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        # (i + 1) / max_iter ends at exactly 1.0 and, unlike
        # i / (max_iter - 1), cannot divide by zero when max_iter == 1.
        progress_bar.progress((i + 1) / max_iter)
    st.balloons()
def load_model(model_path):
    """Load the masked language model and its tokenizer.

    Parameters
    ----------
    model_path : Identifier or path handed to both ``from_pretrained`` calls

    Returns
    -------
    Tuple ``(model, tokenizer)``
    """
    model = TFAutoModelForMaskedLM.from_pretrained(model_path)
    tokenizer = BertTokenizer.from_pretrained(model_path)
    return model, tokenizer
def display_output(status_text, query, current_sentences, previous_sentences):
    """Render the query followed by the current suggestions as list items.

    Parameters
    ----------
    status_text : Placeholder whose ``markdown`` method receives the HTML
    query : First line, shown above the suggestions
    current_sentences : Sentences after the latest mutation round
    previous_sentences : Sentences before that round, used for highlighting
    """
    list_items = [
        # Keep only the text between the first and second comma and drop its
        # last two characters (presumably the trailing " ." left by
        # detokenization -- confirm against the tokenizer output).
        "<li>" + color_new_words(new, old).split(",")[1][:-2] + "</li>"
        for new, old in zip(current_sentences, previous_sentences)
    ]
    html = query + ",<br>" + "".join(list_items)
    status_text.markdown(html, unsafe_allow_html=True)
class TokenWeighter:
    """Hold a 0/1 weight per vocabulary entry of a tokenizer.

    A token gets weight 1 when it is longer than one character and contains
    no ``#``; every other token gets weight 0.
    """

    def __init__(self, tokenizer):
        # Tokenizer whose ``vocab`` mapping (token -> id) is weighed.
        self.tokenizer_ = tokenizer
        # Array of weights indexed by token id.
        self.proba = self.get_token_proba()

    def get_token_proba(self):
        """Return the weight array for the tokenizer's vocabulary."""
        return self._filter_short_partial(self.tokenizer_.vocab)

    def _filter_short_partial(self, vocab):
        """Build the 0/1 array marking tokens that are kept.

        Parameters
        ----------
        vocab : Mapping from token string to token id

        Returns
        -------
        Float array of length ``len(vocab)`` with 1 at the ids of kept tokens
        """
        weights = np.zeros(len(vocab))
        kept_ids = [
            token_id
            for token, token_id in vocab.items()
            if "#" not in token and len(token) > 1
        ]
        weights[kept_ids] = 1
        return weights
class RhymeGenerator:
    """Iteratively fill in rhyming second lines with a masked language model.

    ``start`` builds one token sequence per rhyme word (first line, mask
    tokens, rhyme word, period). Every call to ``mutate`` then masks one
    position per sequence, asks the model for predictions and samples a
    replacement token for that position.
    """

    def __init__(
        self,
        model: "TFAutoModelForMaskedLM",
        tokenizer: "BertTokenizer",
        token_weighter: Optional["TokenWeighter"] = None,
    ):
        """Generate rhymes.

        Parameters
        ----------
        model : Model for masked language modelling
        tokenizer : Tokenizer for model
        token_weighter : Class that weighs tokens; a ``TokenWeighter`` built
            from ``tokenizer`` is used when omitted
        """
        self.model = model
        self.tokenizer = tokenizer
        if token_weighter is None:
            token_weighter = TokenWeighter(tokenizer)
        self.token_weighter = token_weighter
        self._logger = logging.getLogger(__name__)

        # Both are populated by ``start``.
        self.tokenized_rhymes_ = None
        self.position_probas_ = None

        # Easy access.
        self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
        self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
        self.mask_token_id = self.tokenizer.mask_token_id

    def start(self, query: str, rhyme_words: List[str]) -> None:
        """Start the sentence generator.

        Parameters
        ----------
        query : Seed sentence
        rhyme_words : Rhyme words for next sentence

        Raises
        ------
        ValueError
            If ``rhyme_words`` is empty.
        """
        if len(rhyme_words) == 0:
            # Fail early with a clear message instead of inside the padding call.
            raise ValueError("At least one rhyme word is required.")

        self._logger.info("Got sentence %s", query)
        tokenized_rhymes = [
            self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
        ]
        # Make same length.
        self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
            tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
        )
        # Uniform probability over the initially masked positions of each row.
        # ``_initialize_rhymes`` always emits at least one mask, so the row
        # sums are never zero.
        is_mask = self.tokenized_rhymes_ == self.mask_token_id
        self.position_probas_ = is_mask / is_mask.sum(1).reshape(-1, 1)

    def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
        """Initialize the rhymes.

        * Tokenize input
        * Append a comma if the sentence does not end in it (might add better
          predictions as it shows the two sentence parts are related)
        * Make second line as long as the original, but with at least one
          mask token so there is always a position to fill in
        * Add a period

        Parameters
        ----------
        query : First line
        rhyme_word : Last word for second line

        Returns
        -------
        Tokenized rhyme lines
        """
        query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
        rhyme_word_token_ids = self.tokenizer.encode(
            rhyme_word, add_special_tokens=False
        )

        # An empty query yields no tokens; guard the [-1] lookup.
        if not query_token_ids or query_token_ids[-1] != self.comma_token_id:
            query_token_ids.append(self.comma_token_id)

        magic_correction = len(rhyme_word_token_ids) + 1  # 1 for comma
        # Short queries (e.g. a single word) would otherwise get zero masks,
        # which makes the position probabilities in ``start`` 0/0 = NaN.
        n_masks = max(1, len(query_token_ids) - magic_correction)
        return (
            query_token_ids
            + [self.mask_token_id] * n_masks
            + rhyme_word_token_ids
            + [self.period_token_id]
        )

    def mutate(self):
        """Mutate the current rhymes.

        Returns
        -------
        Mutated rhymes, decoded to one string per rhyme word
        """
        self.tokenized_rhymes_ = self._mutate(
            self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
        )
        return [
            self.tokenizer.convert_tokens_to_string(
                self.tokenizer.convert_ids_to_tokens(
                    token_ids, skip_special_tokens=True
                )
            )
            for token_ids in self.tokenized_rhymes_
        ]

    def _mutate(
        self,
        tokenized_rhymes: np.ndarray,
        position_probas: np.ndarray,
        token_id_probas: np.ndarray,
    ) -> np.ndarray:
        """Mask one position per row, predict, and sample a replacement.

        Parameters
        ----------
        tokenized_rhymes : Token ids, one row per rhyme; modified in place
        position_probas : Per-row probability of picking each position
        token_id_probas : Per-token weights applied to the model predictions

        Returns
        -------
        The updated ``tokenized_rhymes``
        """
        # First pass: decide which position to replace in every row and mask it.
        replacements = []
        for i in range(tokenized_rhymes.shape[0]):
            mask_idx, masked_token_ids = self._mask_token(
                tokenized_rhymes[i], position_probas[i]
            )
            tokenized_rhymes[i] = masked_token_ids
            replacements.append(mask_idx)

        # One model call for all rows.
        predictions = self._predict_masked_tokens(tokenized_rhymes)

        # Second pass: fill the chosen position of every row.
        for i, token_ids in enumerate(tokenized_rhymes):
            replace_ix = replacements[i]
            token_ids[replace_ix] = self._draw_replacement(
                predictions[i], token_id_probas, replace_ix
            )
            tokenized_rhymes[i] = token_ids
        return tokenized_rhymes

    def _mask_token(self, token_ids, position_probas):
        """Mask line and return index to update."""
        token_ids = self._mask_repeats(token_ids, position_probas)
        ix = self._locate_mask(token_ids, position_probas)
        token_ids[ix] = self.mask_token_id
        return ix, token_ids

    def _locate_mask(self, token_ids, position_probas):
        """Update masks or a random token."""
        if self.mask_token_id in token_ids:
            # Already masks present, just return the last.
            # We used to return the first but this returns worse predictions.
            return np.where(token_ids == self.mask_token_id)[0][-1]
        # No mask left: draw a position according to ``position_probas``.
        return np.random.choice(range(len(position_probas)), p=position_probas)

    def _mask_repeats(self, token_ids, position_probas):
        """Repeated tokens are generally of less quality."""
        # The last two positions are excluded from the comparison. Index ii
        # refers to the pair (ii, ii + 1), assuming ``pairwise`` yields
        # consecutive pairs -- confirm against its definition.
        repeats = [
            ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
        ]
        for ii in repeats:
            # Only positions with a non-zero position probability are masked.
            if position_probas[ii] > 0:
                token_ids[ii] = self.mask_token_id
            if position_probas[ii + 1] > 0:
                token_ids[ii + 1] = self.mask_token_id
        return token_ids

    def _predict_masked_tokens(self, tokenized_rhymes):
        """Run the model and return the first element of its output."""
        return self.model(tf.constant(tokenized_rhymes))[0]

    def _draw_replacement(self, predictions, token_probas, replace_ix):
        """Get probability, weigh and draw."""
        # TODO (HG): Can't we softmax when calling the model?
        probas = tf.nn.softmax(predictions[replace_ix]).numpy() * token_probas
        probas /= probas.sum()
        return np.random.choice(range(len(probas)), p=probas)
def query_rhyme_words(sentence: str, n_rhymes: int, language:str="english") -> List[str]:
    """Look up rhyme words for the final word of a sentence.

    Parameters
    ----------
    sentence : Sentence that may end with punctuation
    n_rhymes : Maximum number of rhymes to return
    language : Either "english" or "dutch"; selects the rhyme source

    Returns
    -------
    List[str] -- List of words that rhyme with the final word

    Raises
    ------
    NotImplementedError
        If ``language`` is neither "english" nor "dutch".
    """
    final_word = find_last_word(sentence)
    if language == "dutch":
        return mick_rijmwoordenboek(final_word, n_rhymes)
    if language == "english":
        return query_datamuse_api(final_word, n_rhymes)
    raise NotImplementedError(f"Unsupported language ({language}) expected 'english' or 'dutch'.")
def query_datamuse_api(word: str, n_rhymes: Optional[int] = None) -> List[str]:
    """Query the DataMuse API.

    Parameters
    ----------
    word : Word to rhyme with
    n_rhymes : Max rhymes to return; ``None`` returns all of them

    Returns
    -------
    Rhyme words

    Raises
    ------
    requests.RequestException
        If the request times out, fails, or returns an error status.
    """
    response = requests.get(
        "https://api.datamuse.com/words",
        params={"rel_rhy": word},
        # Without a timeout a stalled connection would hang the app forever.
        timeout=10,
    )
    # Surface HTTP errors here rather than as a confusing parse failure below.
    response.raise_for_status()
    words = [entry["word"] for entry in response.json()]
    # Slicing with None keeps the whole list, so no separate branch is needed.
    return words[:n_rhymes]
def mick_rijmwoordenboek(word: str, n_words: int) -> List[str]:
    """Scrape rijmwoordenboek.nl for words that rhyme with ``word``.

    Parameters
    ----------
    word : Word to rhyme with
    n_words : Max rhymes to return

    Returns
    -------
    Up to ``n_words`` rhyme words in random order; an empty list when the
    page has no results element
    """
    url = f"https://rijmwoordenboek.nl/rijm/{word}"
    html = get(url)
    soup = Soup(html)

    container = soup.find("div", {"id": "rhymeResultsWords"})
    # ``find`` may hand back a list (several matches) or nothing at all;
    # normalise both so the ``.html`` access below cannot fail.
    if isinstance(container, list):
        container = container[0] if container else None
    if container is None:
        return []

    results = container.html.split("<br>")
    # clean up
    results = [r.replace("\n", "").replace(" ", "") for r in results]
    # filter html and empty strings
    results = [r for r in results if ("<" not in r) and (len(r) > 0)]
    return random.sample(results, min(len(results), n_words))
if __name__ == "__main__":
    # Configure the root logger at INFO level, then run the app.
    logging.basicConfig(level=logging.INFO)
    main()