File size: 5,517 Bytes
4028f2b
 
 
 
 
 
7805bd4
4028f2b
7805bd4
4028f2b
 
 
 
 
 
7805bd4
4028f2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7805bd4
 
4028f2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7805bd4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import torch
import tiktoken
from model import GPT, GPTConfig
import pickle
import string
import gradio as gr
from contextlib import nullcontext

# Model and Tokenizer setup
device = 'cpu'
# Coarse device family: 'cuda' for any CUDA device string, otherwise 'cpu'.
device_type = 'cuda' if 'cuda' in device else 'cpu'

# Prefer bfloat16 off-CPU when the GPU supports it; fall back to float16 otherwise.
# (On CPU no autocast context is created below, so this dtype goes unused there.)
use_bf16 = device != 'cpu' and torch.cuda.is_bf16_supported()
dtype = 'bfloat16' if use_bf16 else 'float16'
ptdtype = getattr(torch, dtype)

# Autocast only when not on CPU; on CPU a do-nothing context keeps call sites uniform.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Tokenizer: cl100k_base extended with two project-specific special tokens.
cl100k_base = tiktoken.get_encoding("cl100k_base")
im_special_tokens = dict(cl100k_base._special_tokens)
im_special_tokens["<EOR>"] = 100264  # end-of-response marker; matches the UI's default EOR id
im_special_tokens["<SPECIAL2>"] = 100265
enc = tiktoken.Encoding(
    name="cl100k_im",
    pat_str=cl100k_base._pat_str,
    mergeable_ranks=cl100k_base._mergeable_ranks,
    special_tokens=im_special_tokens,
)

# Load model from checkpoint
model_save_path = 'latest_model.pt'
if not os.path.exists(model_save_path):
    raise FileNotFoundError(f"Model file {model_save_path} not found")

# The checkpoint holds a whole pickled model object (it is used directly below via
# .eval()/.to()/.generate()), not a state_dict, so it must be unpickled with
# weights_only=False: PyTorch 2.6+ defaults to weights_only=True and rejects such files.
# The `from model import GPT, GPTConfig` import presumably exists so pickle can resolve
# the class. SECURITY: unpickling can execute arbitrary code -- only load trusted files.
model = torch.load(model_save_path, map_location=device, weights_only=False)

model.eval()
model.to(device)
model = torch.compile(model)

# Tokenizer helpers shared by the generation code below.
def encode(text):
    """Convert *text* to token ids with the extended ``enc`` encoding.

    The literal string "<EOR>" in *text* is permitted and becomes its special
    token id; every other special-token string is left to tiktoken's default
    ``disallowed_special`` handling.
    """
    permitted = {"<EOR>"}
    token_ids = enc.encode(text, allowed_special=permitted)
    return token_ids

def decode(tokens):
    """Map a sequence of token ids back to text with the shared ``enc`` encoding."""
    text = enc.decode(tokens)
    return text

# Function to truncate output to a word limit
def truncate_output(text, token_limit):
    """Cut *text* after its first ``token_limit`` whitespace-separated words.

    Despite the parameter name, the "tokens" counted here are the words that
    ``str.split()`` would produce, not tokenizer ids. Text with at most
    ``token_limit`` words is returned unchanged. Otherwise the result is the
    original text up to the end of the last kept word, followed by ``'...'``.
    Newlines and spacing inside the kept prefix are preserved rather than
    collapsed to single spaces. A non-positive limit keeps no words.
    """
    import re  # local import keeps this helper self-contained

    limit = int(token_limit)  # UI sliders may hand over a float such as 75.0
    # \S+ matches exactly the runs that str.split() returns, with their offsets.
    words = list(re.finditer(r"\S+", text))
    if len(words) <= limit:
        return text
    # Slice the original string so intra-text formatting survives truncation.
    cut = words[limit - 1].end() if limit > 0 else 0
    return text[:cut] + '...'

def ensure_complete_output(output, context, max_length, temperature, top_k, top_p, repetition_penalty, eor_token_id, max_rounds=10):
    """Keep sampling continuations until *output* looks like a finished response.

    Sampling stops when any of these holds:
    - the text has ``max_length`` or more whitespace-separated words;
    - it ends with ``'.'``;
    - the model emits ``eor_token_id``;
    - a round produces no new tokens;
    - ``max_rounds`` rounds have run.

    Args:
        output: Text generated so far, with the prompt already removed.
        context: The original prompt. It seeds generation when *output* is empty.
        max_length: Word budget. It is also passed as ``max_new_tokens`` per round.
        temperature, top_k, top_p, repetition_penalty: Forwarded to ``model.generate``.
        eor_token_id: Token id that marks the end of a response.
        max_rounds: Upper bound on ``model.generate`` calls. It stops the loop
            when the model never emits '.' or EOR.

    Returns:
        *output* with any generated continuation appended. The EOR token itself
        is never included.
    """
    for _ in range(max_rounds):
        if len(output.split()) >= max_length or output.endswith('.'):
            break
        # Condition on the text so far. Fall back to the prompt when nothing has
        # been produced yet, so generate() is never handed an empty context.
        seed_ids = encode(output) if output else encode(context)
        if not seed_ids:
            break
        generated = model.generate(
            torch.tensor(seed_ids, dtype=torch.long, device=device)[None, ...],
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eor_token_id=eor_token_id
        )
        new_ids = generated[0].tolist()
        # generate() appears to return its input ids followed by the sampled ones
        # (generate_text strips the prompt for the same reason). Drop that echo,
        # otherwise the text so far would be appended to itself.
        if new_ids[:len(seed_ids)] == seed_ids:
            new_ids = new_ids[len(seed_ids):]
        if not new_ids:
            break  # no progress; another identical call would loop forever
        if eor_token_id in new_ids:
            output += decode(new_ids[:new_ids.index(eor_token_id)])
            break
        output += decode(new_ids)
    return output

def _trim_at_eor(new_ids, eor_token_id):
    """Return ``(ids, ended)``, where *ids* is *new_ids* cut just before the end-of-response marker.

    The exact ``eor_token_id`` is searched for first. Failing that, the first
    token whose decoded text starts with ``'<E'`` is treated as the marker.
    Presumably the model sometimes spells "<EOR>" out with ordinary tokens.
    ``ended`` is True when either kind of marker was found. The marker itself
    is never kept.
    """
    if eor_token_id in new_ids:
        return new_ids[:new_ids.index(eor_token_id)], True
    for i, token in enumerate(new_ids):
        if decode([token]).startswith('<E'):
            return new_ids[:i], True
    return new_ids, False


# Text generation function for Gradio interface
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    """Generate ``num_samples`` continuations of *prompt* and join them with blank lines.

    Each sample is processed in four steps:
    1. The echoed prompt ids are removed.
    2. The sample is cut at the end-of-response marker, which is excluded.
    3. If no marker was hit, it is extended by ``ensure_complete_output``.
    4. It is truncated to ``max_new_tokens`` whitespace-separated words.

    Returns:
        The processed samples separated by ``'\\n\\n'``.
    """
    # Gradio number widgets may deliver floats; range()/slicing/token ids need ints.
    num_samples = int(num_samples)
    max_new_tokens = int(max_new_tokens)
    if eor_token_id is not None:
        eor_token_id = int(eor_token_id)

    with torch.no_grad(), ctx:
        start_ids = encode(prompt)
        initial_prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
        outputs = []
        for _ in range(num_samples):
            y = model.generate(
                initial_prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eor_token_id=eor_token_id
            )
            output_ids = y[0].tolist()
            # Remove the prompt by id, not by text replacement. Replacing text would
            # also delete copies of the prompt inside the generated text, and would
            # let the '<E' scan below match inside the prompt.
            if output_ids[:len(start_ids)] == start_ids:
                output_ids = output_ids[len(start_ids):]
            output_ids, ended = _trim_at_eor(output_ids, eor_token_id)
            output = decode(output_ids).strip()
            # After an explicit end-of-response, do not sample past it.
            if not ended:
                output = ensure_complete_output(output, prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id)
            outputs.append(truncate_output(output, max_new_tokens))
        return '\n\n'.join(outputs)

# Create a Gradio interface.
# Components use the top-level gr.* classes with `value=`. The old `gr.inputs.*`
# namespace and its `default=` argument were deprecated in Gradio 3 and removed in 4.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", value="Write a short story about a boy:"),
        gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Samples"),
        gr.Slider(minimum=10, maximum=200, step=1, value=75, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Top-k"),
        # step 0.05 so the 0.85 default lies on the slider's grid
        gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.85, label="Top-p"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.1, label="Repetition Penalty"),
        # precision=0 makes Gradio pass an integer token id rather than a float
        gr.Number(value=100264, label="End-of-Response Token ID", precision=0),
    ],
    outputs="text",
    title="GPT Text Generator",
    description="Generate text based on a prompt using a trained GPT model.",
    examples=[
        ["Write a short story about a boy:"],
        ["Explain the theory of relativity:"],
        ["What is the meaning of life?"]
    ]
)

# Launch the Gradio app
demo.launch()