awacke1 commited on
Commit
7c80317
1 Parent(s): 24a40ca

Create new file

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from typing import List
4
+
5
+ import streamlit as st
6
+ from transformers import BertTokenizer, TFAutoModelForMaskedLM
7
+
8
+ from rhyme_with_ai.utils import color_new_words, sanitize
9
+ from rhyme_with_ai.rhyme import query_rhyme_words
10
+ from rhyme_with_ai.rhyme_generator import RhymeGenerator
11
+
12
+
13
+ DEFAULT_QUERY = "Machines will take over the world soon"
14
+ N_RHYMES = 10
15
+
16
+
17
+ LANGUAGE = st.sidebar.radio("Language", ["english", "dutch"],0)
18
+ if LANGUAGE == "english":
19
+ MODEL_PATH = "bert-large-cased-whole-word-masking"
20
+ ITER_FACTOR = 5
21
+ elif LANGUAGE == "dutch":
22
+ MODEL_PATH = "GroNLP/bert-base-dutch-cased"
23
+ ITER_FACTOR = 10 # Faster model
24
+ else:
25
+ raise NotImplementedError(f"Unsupported language ({LANGUAGE}) expected 'english' or 'dutch'.")
26
+
27
+ def main():
28
+ st.markdown(
29
+ "<sup>Created with "
30
+ "[Datamuse](https://www.datamuse.com/api/), "
31
+ "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
32
+ "[Hugging Face](https://huggingface.co/), "
33
+ "[Streamlit](https://streamlit.io/) and "
34
+ "[App Engine](https://cloud.google.com/appengine/)."
35
+ " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
36
+ "or check the "
37
+ "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>",
38
+ unsafe_allow_html=True,
39
+ )
40
+ st.title("Rhyme with AI")
41
+ query = get_query()
42
+ if not query:
43
+ query = DEFAULT_QUERY
44
+ rhyme_words_options = query_rhyme_words(query, n_rhymes=N_RHYMES,language=LANGUAGE)
45
+ if rhyme_words_options:
46
+ logging.getLogger(__name__).info("Got rhyme words: %s", rhyme_words_options)
47
+ start_rhyming(query, rhyme_words_options)
48
+ else:
49
+ st.write("No rhyme words found")
50
+
51
+
52
+ def get_query():
53
+ q = sanitize(
54
+ st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
55
+ )
56
+ if not q:
57
+ return DEFAULT_QUERY
58
+ return q
59
+
60
+
61
+ def start_rhyming(query, rhyme_words_options):
62
+ st.markdown("## My Suggestions:")
63
+
64
+ progress_bar = st.progress(0)
65
+ status_text = st.empty()
66
+ max_iter = len(query.split()) * ITER_FACTOR
67
+
68
+ rhyme_words = rhyme_words_options[:N_RHYMES]
69
+
70
+ model, tokenizer = load_model(MODEL_PATH)
71
+ sentence_generator = RhymeGenerator(model, tokenizer)
72
+ sentence_generator.start(query, rhyme_words)
73
+
74
+ current_sentences = [" " for _ in range(N_RHYMES)]
75
+ for i in range(max_iter):
76
+ previous_sentences = copy.deepcopy(current_sentences)
77
+ current_sentences = sentence_generator.mutate()
78
+ display_output(status_text, query, current_sentences, previous_sentences)
79
+ progress_bar.progress(i / (max_iter - 1))
80
+ st.balloons()
81
+
82
+
83
+ @st.cache(allow_output_mutation=True)
84
+ def load_model(model_path):
85
+ return (
86
+ TFAutoModelForMaskedLM.from_pretrained(model_path),
87
+ BertTokenizer.from_pretrained(model_path),
88
+ )
89
+
90
+
91
+ def display_output(status_text, query, current_sentences, previous_sentences):
92
+ print_sentences = []
93
+ for new, old in zip(current_sentences, previous_sentences):
94
+ formatted = color_new_words(new, old)
95
+ after_comma = "<li>" + formatted.split(",")[1][:-2] + "</li>"
96
+ print_sentences.append(after_comma)
97
+ status_text.markdown(
98
+ query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
99
+ )
100
+
101
+
102
+
103
+ if __name__ == "__main__":
104
+ logging.basicConfig(level=logging.INFO)
105
+ main()