anshu-man853 commited on
Commit
0361c00
·
verified ·
1 Parent(s): e2252be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -12
app.py CHANGED
@@ -1,28 +1,50 @@
1
  import gradio as gr
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
3
 
4
  # Load the GPT-2 model and tokenizer
5
  model_name = "gpt2"
6
- model = GPT2LMHeadModel.from_pretrained(model_name)
 
 
 
7
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
8
 
9
  # Define the sentence completion function
10
  def complete_sentence(sentence):
11
- input_ids = tokenizer.encode(sentence, return_tensors="pt")
12
- output = model.generate(input_ids, max_length=50, num_return_sequences=1)
13
- completed_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
14
- return completed_sentence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Create the Gradio interface
17
  iface = gr.Interface(
18
- fn=complete_sentence,
19
- inputs="text",
20
- outputs="text",
21
- title="Sentence Completion",
22
- description="Enter a sentence to complete",
23
- example="I love to"
24
  )
25
 
26
  # Launch the Gradio interface
27
  if __name__ == "__main__":
28
- iface.launch()
 
1
  import gradio as gr
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
+ import torch
4
 
5
  # Load the GPT-2 model and tokenizer
6
  model_name = "gpt2"
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ # Load model and tokenizer, move model to the correct device
10
+ model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
11
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
12
 
13
  # Define the sentence completion function
14
  def complete_sentence(sentence):
15
+ if not sentence.strip():
16
+ return "Please enter a valid input."
17
+ try:
18
+ # Encode the input sentence
19
+ input_ids = tokenizer.encode(sentence, return_tensors="pt").to(device)
20
+
21
+ # Generate completion
22
+ output = model.generate(
23
+ input_ids,
24
+ max_length=50,
25
+ num_return_sequences=1,
26
+ no_repeat_ngram_size=2,
27
+ temperature=0.7,
28
+ top_p=0.9,
29
+ do_sample=True,
30
+ )
31
+
32
+ # Decode the generated sentence
33
+ completed_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
34
+ return completed_sentence
35
+ except Exception as e:
36
+ return f"An error occurred: {str(e)}"
37
 
38
  # Create the Gradio interface
39
  iface = gr.Interface(
40
+ fn=complete_sentence,
41
+ inputs="text",
42
+ outputs="text",
43
+ title="Sentence Completion",
44
+ description="Enter a sentence to complete.",
45
+ examples=["I love to", "The future of AI is", "Once upon a time"],
46
  )
47
 
48
  # Launch the Gradio interface
49
  if __name__ == "__main__":
50
+ iface.launch()