saintyboy committed on
Commit 4028f2b
1 Parent(s): 16d99c0

Update app.py

Files changed (1)
app.py +128 -7
app.py CHANGED
@@ -1,15 +1,136 @@
+import os
+import torch
+import tiktoken
+from model import GPT, GPTConfig
+import pickle
+import string
 import gradio as gr
-from transformers import pipeline
+from contextlib import nullcontext
 
-# Load the sentiment analysis model
-classifier = pipeline('sentiment-analysis')
+# Model and Tokenizer setup
+device = 'cpu'
+dtype = 'bfloat16' if device != 'cpu' and torch.cuda.is_bf16_supported() else 'float16'
+device_type = 'cuda' if 'cuda' in device else 'cpu'
+ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
-def sentiment_analysis(text):
-    result = classifier(text)
-    return result[0]['label'], result[0]['score']
+cl100k_base = tiktoken.get_encoding("cl100k_base")
+enc = tiktoken.Encoding(
+    name="cl100k_im",
+    pat_str=cl100k_base._pat_str,
+    mergeable_ranks=cl100k_base._mergeable_ranks,
+    special_tokens={
+        **cl100k_base._special_tokens,
+        "<EOR>": 100264,
+        "<SPECIAL2>": 100265,
+    }
+)
+
+# Load model from checkpoint
+model_save_path = 'latest_model.pt'
+if os.path.exists(model_save_path):
+    model = torch.load(model_save_path, map_location=device)
+else:
+    raise FileNotFoundError(f"Model file {model_save_path} not found")
+
+model.eval()
+model.to(device)
+model = torch.compile(model)
+
+# Function to encode and decode using the tokenizer
+def encode(text):
+    return enc.encode(text, allowed_special={"<EOR>"})
+
+def decode(tokens):
+    return enc.decode(tokens)
+
+# Function to truncate output to token limit
+def truncate_output(text, token_limit):
+    tokens = text.split()
+    if len(tokens) > token_limit:
+        return ' '.join(tokens[:token_limit]) + '...'
+    return text
+
+def ensure_complete_output(output, context, max_length, temperature, top_k, top_p, repetition_penalty, eor_token_id):
+    while len(output.split()) < max_length and not output.endswith('.'):
+        continuation = model.generate(
+            torch.tensor(encode(output), dtype=torch.long, device=device)[None, ...],
+            max_new_tokens=max_length,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            eor_token_id=eor_token_id
+        )
+        continuation_text = decode(continuation[0].tolist())
+        if eor_token_id in continuation[0].tolist():
+            continuation_text = continuation_text.split("<EOR>")[0]
+            output += continuation_text
+            break
+        else:
+            output += continuation_text
+            if len(output.split()) >= max_length:
+                break
+    return output
+
+# Text generation function for Gradio interface
+def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
+    with torch.no_grad():
+        with ctx:
+            start_ids = encode(prompt)
+            initial_prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
+            outputs = []
+            for _ in range(num_samples):
+                y = model.generate(
+                    initial_prompt,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    repetition_penalty=repetition_penalty,
+                    eor_token_id=eor_token_id
+                )
+                # Filter out tokens after the end-of-response token or similar markers
+                output_ids = y[0].tolist()
+                if eor_token_id in output_ids:
+                    output_ids = output_ids[:output_ids.index(eor_token_id) + 1]  # Include EOR token
+                else:
+                    # Check for similar markers like '<E' and handle them
+                    try:
+                        eor_index = next(i for i, token in enumerate(output_ids) if decode([token]).startswith('<E'))
+                        output_ids = output_ids[:eor_index]
+                    except StopIteration:
+                        pass
+
+                # Ensure the prompt is not included in the final output
+                output = decode(output_ids).replace(prompt, '').strip()
+                output = ensure_complete_output(output, prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id)
+                truncated_output = truncate_output(output, max_new_tokens)
+                outputs.append(truncated_output)
+            return '\n\n'.join(outputs)
 
 # Create a Gradio interface
-demo = gr.Interface(fn=sentiment_analysis, inputs="text", outputs=["text", "number"])
+demo = gr.Interface(
+    fn=generate_text,
+    inputs=[
+        gr.inputs.Textbox(lines=2, placeholder="Enter your prompt here...", default="Write a short story about a boy:"),
+        gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Samples"),
+        gr.inputs.Slider(minimum=10, maximum=200, step=1, default=75, label="Max New Tokens"),
+        gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.8, label="Temperature"),
+        gr.inputs.Slider(minimum=1, maximum=100, step=1, default=50, label="Top-k"),
+        gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.85, label="Top-p"),
+        gr.inputs.Slider(minimum=1.0, maximum=2.0, step=0.1, default=1.1, label="Repetition Penalty"),
+        gr.inputs.Number(default=100264, label="End-of-Response Token ID")
+    ],
+    outputs="text",
+    title="GPT Text Generator",
+    description="Generate text based on a prompt using a trained GPT model.",
+    examples=[
+        ["Write a short story about a boy:"],
+        ["Explain the theory of relativity:"],
+        ["What is the meaning of life?"]
+    ]
+)
 
 # Launch the Gradio app
 demo.launch()
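
A quick sanity check for the extended tokenizer the commit defines (a hypothetical snippet, reusing the enc object from app.py above): encode() must be told which special tokens are allowed, since tiktoken raises a ValueError when the literal string <EOR> appears in input text without permission.

# Hypothetical round-trip check for the extended "cl100k_im" encoding above.
# Without allowed_special={"<EOR>"}, tiktoken raises a ValueError when the
# raw text contains the special-token string.
ids = enc.encode("Hello<EOR>", allowed_special={"<EOR>"})
assert ids[-1] == 100264                 # "<EOR>" maps to its registered id
assert enc.decode(ids) == "Hello<EOR>"   # decode restores the special-token text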
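
One note on the checkpoint load: torch.load here unpickles a full model object, which is why from model import GPT, GPTConfig is needed even though the names are never referenced; unpickling must be able to resolve the GPT class. On PyTorch 2.6 and later, where torch.load defaults to weights_only=True, a full-object load like this needs the flag set explicitly. A minimal sketch, assuming the same latest_model.pt file:

import torch
from model import GPT, GPTConfig  # unpickling resolves the GPT class from here

# Sketch for PyTorch 2.6+: a fully pickled model object requires
# weights_only=False, since the new default of True rejects arbitrary objects.
model = torch.load('latest_model.pt', map_location='cpu', weights_only=False)
model.eval()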
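
Also worth noting: truncate_output counts whitespace-separated words, not tokenizer tokens, so its limit only approximates max_new_tokens. If an exact token budget is wanted, a hedged variant reusing the same encode()/decode() helpers could look like this (truncate_output_tokens is a hypothetical name, not part of the commit):

# Hypothetical variant that truncates on real tokenizer tokens rather than
# whitespace-separated words, reusing encode()/decode() from app.py above.
def truncate_output_tokens(text, token_limit):
    tokens = encode(text)
    if len(tokens) > token_limit:
        return decode(tokens[:token_limit]) + '...'
    return text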
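
Finally, the interface uses the gr.inputs.* constructors and the default= keyword from the Gradio 2.x API; these were deprecated in Gradio 3.0 and removed in later releases, so this Space will fail on a current Gradio unless the runtime is pinned. A minimal sketch of the same interface on Gradio 4.x, assuming generate_text from the commit above (components move to the top level and default= becomes value=):

import gradio as gr

# Same interface on Gradio 4.x: gr.inputs.* -> gr.*, default= -> value=.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", value="Write a short story about a boy:", label="Prompt"),
        gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Samples"),
        gr.Slider(minimum=10, maximum=200, step=1, value=75, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Top-k"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.85, label="Top-p"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.1, label="Repetition Penalty"),
        gr.Number(value=100264, label="End-of-Response Token ID"),
    ],
    outputs="text",
    title="GPT Text Generator",
    description="Generate text based on a prompt using a trained GPT model.",
)

demo.launch()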