# -*- coding: utf-8 -*-

import os
import random

import streamlit as st
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import tokenizers

#os.environ["TOKENIZERS_PARALLELISM"] = "false"

random.seed(None)
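# Fallback Hebrew prompts, used when the text box is left empty (roughly: "The
# demon appeared in front of", "Kali drew the", "Once upon a time, many years
# ago", "Harry Potter smiled an embarrassed smile", "And then I broke all the
# rules of the ceremony when").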
suggested_text_list = ['ื”ืฉื“ ื”ื•ืคื™ืข ืžื•ืœ', 'ืงืืœื™ ืฉืœืคื” ืืช', 'ืคืขื ืื—ืช ืœืคื ื™ ืฉื ื™ื ืจื‘ื•ืช', 'ื”ืืจื™ ืคื•ื˜ืจ ื—ื™ื™ืš ื—ื™ื•ืš ื ื‘ื•ืš', 'ื•ืื– ื”ืคืจืชื™ ืืช ื›ืœ ื›ืœืœื™ ื”ื˜ืงืก ื›ืฉ']

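# Load the model once and cache it across Streamlit reruns; tokenizers objects
# are not hashable, so st.cache is told to hash them by object identity.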
@st.cache(hash_funcs={tokenizers.Tokenizer: id, tokenizers.AddedToken: id})
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer

def extend(input_text, max_size=20, top_k=50, top_p=0.95):
    if input_text is None:
        input_text = ""

    encoded_prompt = tokenizer.encode(
        input_text, add_special_tokens=False, return_tensors="pt")

    encoded_prompt = encoded_prompt.to(device)

    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt
    
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_size + len(encoded_prompt[0]),
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=5.0,
        num_return_sequences=1)

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []

    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token (str.find returns -1 when the
        # token is absent, which would otherwise drop the last character)
        if stop_token and text.find(stop_token) >= 0:
            text = text[: text.find(stop_token)]

        # Remove all text after three consecutive newlines
        if new_lines and text.find(new_lines) >= 0:
            text = text[: text.find(new_lines)]

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            input_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
    
    # num_return_sequences=1, so there is exactly one generated sequence
    parsed_text = generated_sequences[0].replace("<|startoftext|>", "").replace("\r", "").replace("\n\n", "\n")
    if len(parsed_text) == 0:
        parsed_text = "ืฉื’ื™ืื”"  # "Error" (in Hebrew)
    return parsed_text
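
# Minimal usage sketch (assumes the module-level model, tokenizer, device,
# stop_token and new_lines globals are initialized, as in __main__ below):
#   extend("ืคืขื ืื—ืช ืœืคื ื™ ืฉื ื™ื ืจื‘ื•ืช", max_size=50, top_k=40, top_p=0.92)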


if __name__ == "__main__":
    st.title("Hebrew text generator: Science Fiction and Fantasy (GPT-Neo)")
    model, tokenizer = load_model("./model")
    
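    # Markers used by extend() to truncate the generated text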
    stop_token = "<|endoftext|>"
    new_lines = "\n\n\n"

    np.random.seed(None)
    random_seed = int(np.random.randint(10000))  # torch.manual_seed expects an int

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)

    model.to(device)

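    # Default prompt: "The last man on Earth sat alone in his room, when
    # suddenly there was a knock"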
    text_area = st.text_area('Enter the first few words (or leave blank), tap on "Generate Text" below. Tapping again will produce a different result.', 'ื”ืื“ื ื”ืื—ืจื•ืŸ ืขืœื™ ืื“ืžื•ืช ื™ืฉื‘ ืœื‘ื“ ื‘ื—ื“ืจื• ื›ืฉืœืคืชืข ื ืฉืžืขื” ื“ืคื™ืงื”')

    st.sidebar.subheader("Configurable parameters")

    max_len = st.sidebar.slider("Max-Length", 0, 512, 160, help="The maximum length of the sequence to be generated.")
    top_k = st.sidebar.slider("Top-K", 0, 100, 40, help="The number of highest-probability vocabulary tokens to keep for top-k filtering.")
    top_p = st.sidebar.slider("Top-P", 0.0, 1.0, 0.92, help="If set to a float < 1, only the smallest set of the most probable tokens whose cumulative probability reaches top_p is kept for generation.")

    if st.button("Generate Text"):
        with st.spinner(text="Generating results..."):
            st.subheader("Result")
            print(f"device:{device}, n_gpu:{n_gpu}, random_seed:{random_seed}, maxlen:{max_len}, top_k:{top_k}, top_p:{top_p}")
            if len(text_area.strip()) == 0:
                text_area = random.choice(suggested_text_list)
            result = extend(input_text=text_area,
                            max_size=int(max_len),
                            top_k=int(top_k),
                            top_p=float(top_p))

            print("Done length: " + str(len(result)) + " bytes") 
            #<div class="rtl" dir="rtl" style="text-align:right;">
            st.markdown(f"<p dir=\"rtl\" style=\"text-align:right;\"> {result} </p>", unsafe_allow_html=True)
            st.write("\n\nResult length: " + str(len(result)) + " bytes\n Random seed: " + str(random_seed) + "\ntop_k: " + str(top_k) + "\ntop_p: " + str(top_p) + "\nmax_len: " + str(max_len) + "\ndevice: " + str(device) + "\nn_gpu: " + str(n_gpu))
            print(f"\"{result}\"")      
    
    st.markdown(
        """Hebrew text generation model based on EleutherAI's gpt-neo architecture. Originally trained on a TPUv3-8 which was made available to me via the [TPU Research Cloud Program](https://sites.research.google/trc/). The model was then slightly fine-tuned on science fiction and fantasy text."""
    )

    st.markdown("<footer><hr><p style=\"font-size:14px\">The site is fan made and is not affiliated with any author in any way.</p><p style=\"font-size:12px\">By <a href=\"https://linktr.ee/Norod78\">Doron Adler</a></p></footer> ", unsafe_allow_html=True)