import copy
import logging
from typing import List, Optional

import numpy as np
import tensorflow as tf
import streamlit as st
from transformers import BertTokenizer, TFAutoModelForMaskedLM

from rhyme_with_ai.utils import color_new_words, pairwise, sanitize
from rhyme_with_ai.token_weighter import TokenWeighter
from rhyme_with_ai.rhyme import query_rhyme_words


DEFAULT_QUERY = "Machines will take over the world soon"
N_RHYMES = 10
ITER_FACTOR = 5

# The sidebar widget is evaluated at import time; the selected language
# decides which pretrained masked-language model is loaded.
LANGUAGE = st.sidebar.radio("Language", ["english", "dutch"], 0)
if LANGUAGE == "english":
    MODEL_PATH = "bert-large-cased-whole-word-masking"
elif LANGUAGE == "dutch":
    MODEL_PATH = "GroNLP/bert-base-dutch-cased"
else:
    raise NotImplementedError(
        f"Unsupported language ({LANGUAGE}) " "expected 'english' or 'dutch'."
    )


def main():
    """Render the page: credits, title, query input and the rhyme suggestions."""
    st.markdown(
        "Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).",
        unsafe_allow_html=True,
    )
    st.title("Rhyme with AI")
    # get_query() already substitutes DEFAULT_QUERY for an empty input, so no
    # second fallback is needed here.
    query = get_query()
    rhyme_words_options = query_rhyme_words(
        query, n_rhymes=N_RHYMES, language=LANGUAGE
    )
    if rhyme_words_options:
        logging.getLogger(__name__).info("Got rhyme words: %s", rhyme_words_options)
        start_rhyming(query, rhyme_words_options)
    else:
        st.write("No rhyme words found")


def get_query():
    """Return the sanitized first line typed by the user, or DEFAULT_QUERY if empty."""
    q = sanitize(
        st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
    )
    if not q:
        return DEFAULT_QUERY
    return q


def start_rhyming(query, rhyme_words_options):
    """Generate rhyming lines for ``query`` and refresh them on screen each iteration.

    Parameters
    ----------
    query : First line to rhyme with
    rhyme_words_options : Candidate rhyme words; at most N_RHYMES are used
    """
    st.markdown("## My Suggestions:")

    progress_bar = st.progress(0)
    status_text = st.empty()

    # The number of mutation rounds grows with the number of words in the query.
    max_iter = len(query.split()) * ITER_FACTOR

    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)

    current_sentences = [" " for _ in range(N_RHYMES)]
    for i in range(max_iter):
        previous_sentences = copy.deepcopy(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        # (i + 1) / max_iter ends at exactly 1.0 and, unlike i / (max_iter - 1),
        # cannot divide by zero when only a single iteration is run.
        progress_bar.progress((i + 1) / max_iter)
    st.balloons()


@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """Load the masked-LM model and its tokenizer from ``model_path`` (cached by Streamlit)."""
    return (
        TFAutoModelForMaskedLM.from_pretrained(model_path),
        BertTokenizer.from_pretrained(model_path),
    )


def display_output(status_text, query, current_sentences, previous_sentences):
    """Write the query followed by the current suggestions into ``status_text``.

    Parameters
    ----------
    status_text : Streamlit placeholder whose ``markdown`` method is called
    query : First line, shown above the suggestions
    current_sentences : Sentences produced by the latest mutation
    previous_sentences : Sentences from the previous round, passed to
        ``color_new_words`` together with the current ones
    """
    # NOTE(review): the three string literals below were split across physical
    # lines in the file as received, which is not valid Python. They are
    # reproduced with explicit escapes; they may originally have contained
    # HTML markup that was lost -- confirm against version control.
    print_sentences = []
    for new, old in zip(current_sentences, previous_sentences):
        formatted = color_new_words(new, old)
        segments = formatted.split(",")
        # The suggestion is the text between the first and second comma. Fall
        # back to the whole text rather than raising IndexError if there is no
        # comma at all.
        second_line = segments[1] if len(segments) > 1 else segments[0]
        # [:-2] drops the final two characters (presumably the trailing " ."
        # appended during generation -- confirm).
        after_comma = "\n  • " + second_line[:-2] + "\n  • "
        print_sentences.append(after_comma)
    status_text.markdown(
        query + ",\n    " + "".join(print_sentences), unsafe_allow_html=True
    )


class RhymeGenerator:
    def __init__(
        self,
        model: TFAutoModelForMaskedLM,
        tokenizer: BertTokenizer,
        token_weighter: Optional[TokenWeighter] = None,
    ):
        """Generate rhymes.

        Parameters
        ----------
        model : Model for masked language modelling
        tokenizer : Tokenizer for model
        token_weighter : Class that weighs tokens; a ``TokenWeighter`` built
            from ``tokenizer`` is used when omitted
        """
        self.model = model
        self.tokenizer = tokenizer
        if token_weighter is None:
            token_weighter = TokenWeighter(tokenizer)
        self.token_weighter = token_weighter
        self._logger = logging.getLogger(__name__)

        # Both are populated by start().
        self.tokenized_rhymes_ = None
        self.position_probas_ = None

        # Easy access.
        self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
        self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
        self.mask_token_id = self.tokenizer.mask_token_id

    def start(self, query: str, rhyme_words: List[str]) -> None:
        """Start the sentence generator.

        Parameters
        ----------
        query : Seed sentence
        rhyme_words : Rhyme words for next sentence
        """
        # TODO: What if no content (empty ``rhyme_words``)?
        self._logger.info("Got sentence %s", query)
        tokenized_rhymes = [
            self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
        ]
        # Make same length.
        self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
            tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
        )
        # Only positions that start out masked may be resampled later; every
        # such position in a row gets the same probability. Each row holds at
        # least one mask (see _initialize_rhymes), so the row sum is non-zero.
        p = self.tokenized_rhymes_ == self.mask_token_id
        self.position_probas_ = p / p.sum(1).reshape(-1, 1)

    def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
        """Initialize the rhymes.

        * Tokenize input
        * Append a comma if the sentence does not end in it (might add better
          predictions as it shows the two sentence parts are related)
        * Make second line as long as the original
        * Add a period

        Parameters
        ----------
        query : First line
        rhyme_word : Last word for second line

        Returns
        -------
        Tokenized rhyme lines
        """
        query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
        rhyme_word_token_ids = self.tokenizer.encode(
            rhyme_word, add_special_tokens=False
        )

        # The emptiness check avoids an IndexError on [-1] when the query
        # tokenizes to nothing.
        if not query_token_ids or query_token_ids[-1] != self.comma_token_id:
            query_token_ids.append(self.comma_token_id)

        magic_correction = len(rhyme_word_token_ids) + 1  # 1 for comma
        # Always keep at least one mask: with a short query or a long rhyme
        # word the difference is <= 0, which would leave the row without any
        # position to sample and make its position probabilities 0 / 0 = NaN.
        n_masks = max(1, len(query_token_ids) - magic_correction)
        return (
            query_token_ids
            + [self.mask_token_id] * n_masks
            + rhyme_word_token_ids
            + [self.period_token_id]
        )

    def mutate(self):
        """Mutate the current rhymes.

        Returns
        -------
        Mutated rhymes
        """
        self.tokenized_rhymes_ = self._mutate(
            self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
        )

        # Decode every row back to text, leaving out special tokens.
        return [
            self.tokenizer.convert_tokens_to_string(
                self.tokenizer.convert_ids_to_tokens(
                    token_ids, skip_special_tokens=True
                )
            )
            for token_ids in self.tokenized_rhymes_
        ]

    def _mutate(
        self,
        tokenized_rhymes: np.ndarray,
        position_probas: np.ndarray,
        token_id_probas: np.ndarray,
    ) -> np.ndarray:
        """Mask one position per row, predict it and write the drawn token back.

        ``tokenized_rhymes`` is modified in place and also returned.
        """
        replacements = []
        for i in range(tokenized_rhymes.shape[0]):
            mask_idx, masked_token_ids = self._mask_token(
                tokenized_rhymes[i], position_probas[i]
            )
            tokenized_rhymes[i] = masked_token_ids
            replacements.append(mask_idx)

        # One model call for all rows.
        predictions = self._predict_masked_tokens(tokenized_rhymes)

        for i, token_ids in enumerate(tokenized_rhymes):
            replace_ix = replacements[i]
            token_ids[replace_ix] = self._draw_replacement(
                predictions[i], token_id_probas, replace_ix
            )
            tokenized_rhymes[i] = token_ids

        return tokenized_rhymes

    def _mask_token(self, token_ids, position_probas):
        """Mask line and return index to update."""
        token_ids = self._mask_repeats(token_ids, position_probas)
        ix = self._locate_mask(token_ids, position_probas)
        token_ids[ix] = self.mask_token_id
        return ix, token_ids

    def _locate_mask(self, token_ids, position_probas):
        """Update masks or a random token."""
        if self.mask_token_id in token_ids:
            # Already masks present, just return the last.
            # We used to return the first but this returns worse predictions.
            return np.where(token_ids == self.mask_token_id)[0][-1]
        return np.random.choice(len(position_probas), p=position_probas)

    def _mask_repeats(self, token_ids, position_probas):
        """Repeated tokens are generally of less quality."""
        # The last two tokens are left out of the comparison.
        repeats = [
            ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
        ]
        for ii in repeats:
            # Only positions with a non-zero sampling probability are re-masked.
            if position_probas[ii] > 0:
                token_ids[ii] = self.mask_token_id
            if position_probas[ii + 1] > 0:
                token_ids[ii + 1] = self.mask_token_id
        return token_ids

    def _predict_masked_tokens(self, tokenized_rhymes):
        """Run the model on all rows and return its first output."""
        return self.model(tf.constant(tokenized_rhymes))[0]

    def _draw_replacement(self, predictions, token_probas, replace_ix):
        """Get probability, weigh and draw."""
        # TODO (HG): Can't we softmax when calling the model?
        # Normalise in float64: np.random.choice validates that ``p`` sums to 1
        # with a tight tolerance, which float32 rounding can miss.
        probas = tf.nn.softmax(predictions[replace_ix]).numpy().astype(np.float64)
        probas = probas * token_probas
        probas /= probas.sum()
        return np.random.choice(len(probas), p=probas)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()