sagar007 commited on
Commit
02cf0bb
·
verified ·
1 Parent(s): 1ffcf64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -31
app.py CHANGED
@@ -1,43 +1,48 @@
 
1
  import torch
2
  import gradio as gr
3
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
4
-
5
- # Load the tokenizer and the model
6
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
7
- model = GPT2LMHeadModel.from_pretrained('gpt2')
8
 
9
- # Load the best model weights
10
- model.load_state_dict(torch.load('GPT_model.pth', map_location=torch.device('cpu')))
 
 
 
 
 
11
 
12
- # Set the model to evaluation mode
13
- model.eval()
14
 
15
- # Define the text generation function
16
- def generate_text(prompt, max_length=50, num_return_sequences=1):
17
- inputs = tokenizer(prompt, return_tensors='pt')
18
- outputs = model.generate(
19
- inputs.input_ids,
20
- max_length=max_length,
21
- num_return_sequences=num_return_sequences,
22
- do_sample=True,
23
- top_k=50,
24
- top_p=0.95,
25
- temperature=1.0
26
- )
27
- return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
 
 
28
 
29
- # Define the Gradio interface
30
- interface = gr.Interface(
31
  fn=generate_text,
32
  inputs=[
33
- gr.inputs.Textbox(lines=2, placeholder="Enter your prompt here..."),
34
- gr.inputs.Slider(minimum=10, maximum=200, default=50, label="Max Length"),
35
- gr.inputs.Slider(minimum=1, maximum=5, default=1, label="Number of Sequences")
36
  ],
37
- outputs=gr.outputs.Textbox(),
38
  title="GPT-2 Text Generator",
39
- description="Enter a prompt to generate text using GPT-2.",
40
  )
41
 
42
- # Launch the Gradio interface
43
- interface.launch()
 
1
+
2
  import torch
3
  import gradio as gr
4
+ from model import GPT, GPTConfig # Assuming your model code is in a file named model.py
5
+ import tiktoken
 
 
 
6
 
7
+ # Load the trained model
8
+ def load_model(model_path):
9
+ config = GPTConfig() # Adjust this if you've changed the default config
10
+ model = GPT(config)
11
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
12
+ model.eval()
13
+ return model
14
 
15
+ model = load_model('GPT_model.pth') # Replace with the actual path to your .pth file
16
+ enc = tiktoken.get_encoding('gpt2')
17
 
18
+ def generate_text(prompt, max_length=100, temperature=0.7):
19
+ input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
20
+
21
+ with torch.no_grad():
22
+ for _ in range(max_length):
23
+ outputs = model(input_ids)
24
+ next_token_logits = outputs[0][:, -1, :] / temperature
25
+ next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1)
26
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
27
+
28
+ if next_token.item() == enc.encode('\n')[0]:
29
+ break
30
+
31
+ generated_text = enc.decode(input_ids[0].tolist())
32
+ return generated_text
33
 
34
+ # Gradio interface
35
+ iface = gr.Interface(
36
  fn=generate_text,
37
  inputs=[
38
+ gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
39
+ gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
40
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
41
  ],
42
+ outputs=gr.Textbox(label="Generated Text"),
43
  title="GPT-2 Text Generator",
44
+ description="Enter a prompt and generate text using a fine-tuned GPT-2 model."
45
  )
46
 
47
+ # Launch the app
48
+ iface.launch()