File size: 2,907 Bytes
a5f152a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import logging

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Model to sample from: a Hugging Face Hub id, or a local directory such as
# "./TinyStories-3M-val-Hebrew" to load a local checkpoint instead.
model_id = "Norod78/TinyStories-3M-val-Hebrew"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Alternative: load a local TensorFlow checkpoint, e.g.
# model = AutoModelForCausalLM.from_pretrained("./Hebrew_GPT3_XL", from_tf=True)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Prompt that generation is conditioned on. Alternative (Hebrew) prompts,
# with English glosses:
# prompt_text = "אתמול, בדרך הביתה, גיליתי ש"       # "Yesterday, on the way home, I discovered that"
# prompt_text = "פעם, לפני ש"                        # "Once, before"
# prompt_text = "הסוד השמור ביותר של תעשיית היופי"   # "The best-kept secret of the beauty industry"
# prompt_text = "<|startoftext|>"
prompt_text = "\n"
# Markers at which the decoded output text is truncated when printed.
stop_token = "<|endoftext|>"
new_lines = "\n\n\n"
seed = 1000

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

logger.info("device: %s, n_gpu: %s", device, n_gpu)

# Seed the NumPy and torch RNGs so sampling is repeatable across runs.
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

model.to(device)
# Optional: uncomment to run the model in float16.
# model.half()

def process_output_sequences(output_sequences):
    """Decode, clean up and print every sequence returned by ``model.generate``.

    For each sequence, the decoded text:

    * has ``<|startoftext|>`` markers removed;
    * has ``" ; "`` turned into line breaks;
    * is cut at the first ``stop_token``, then at the first ``new_lines`` run,
      but only when that marker is set and actually present.

    Each sequence is printed under a numbered header, followed by a final
    ``------`` separator. Nothing is returned and the input tensor is not
    modified.

    Args:
        output_sequences: tensor of token ids, normally
            ``(num_sequences, seq_len)``. Tensors with any other rank are
            reshaped to ``(-1, seq_len)``: extra leading dimensions are folded
            into the sequence axis, and a 1-D tensor is treated as a single
            sequence.
    """
    sequences = output_sequences
    if sequences.dim() != 2:
        # Reshape (not in-place squeeze_) so the caller's tensor is untouched,
        # and a lone sequence of shape (1, 1, L) stays 2-D instead of being
        # flattened to (L,) and decoded token by token.
        sequences = sequences.reshape(-1, sequences.shape[-1])

    for number, sequence in enumerate(sequences, start=1):
        print(f"=== GENERATED SEQUENCE {number} ===")
        text = tokenizer.decode(sequence.tolist(), clean_up_tokenization_spaces=True)
        text = text.replace("<|startoftext|>", "").replace(" ; ", "\n")
        # str.find returns -1 for a missing marker, and text[:-1] would drop
        # the last character; partition() keeps the whole string instead.
        if stop_token:
            text = text.partition(stop_token)[0]
        if new_lines:
            text = text.partition(new_lines)[0]
        print(text)
    print("------")


def encode_prompt(text):
    """Tokenize ``text`` (adding special tokens) and place the ids on ``device``.

    Returns:
        The PyTorch tensor of token ids, or ``None`` when the encoding has
        length 0 along its last dimension (the prompt produced no tokens).
    """
    token_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        return_tensors="pt",
    ).to(device)
    # An empty encoding is reported as "no prompt" rather than a zero-length tensor.
    return None if token_ids.shape[-1] == 0 else token_ids

# Encode the prompt, sample several continuations and print them.
input_ids = encode_prompt(prompt_text)
# encode_prompt returns None when the prompt yields no tokens; generate() is
# then called without a prompt, so treat that case as a zero-length prompt
# instead of crashing on None.size().
input_ids_len = 0 if input_ids is None else input_ids.shape[-1]
# Ask for up to 192 tokens beyond the prompt, capped at 1023 total positions.
# NOTE(review): 1023 looks like a GPT-2-style context limit; confirm against
# this checkpoint's model.config.
max_length = min(input_ids_len + 192, 1023)

output_sequences = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    temperature=0.98,
    top_k=40,
    top_p=0.92,
    repetition_penalty=2.0,
    do_sample=True,
    num_return_sequences=5,
)

process_output_sequences(output_sequences)