Pranav0111 committed on
Commit
4a18f0c
1 Parent(s): 2a8ae3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -21
app.py CHANGED
@@ -2,12 +2,20 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import random
 
 
 
 
5
 
6
  # Load model and tokenizer
7
  model_name = "gpt2-medium" # You can change this to your fine-tuned model
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
 
 
 
 
11
  # Sample prompts for different styles
12
  style_prompts = {
13
  "Classic": "Here's a romantic pickup line: ",
@@ -25,27 +33,31 @@ sentiment_adjustments = {
25
 
26
def generate_pickup_line(style, sentiment, temperature=0.7, max_length=50):
    """Generate a pickup line based on selected style and sentiment.

    Args:
        style: Key into the module-level ``style_prompts`` dict.
        sentiment: Key into the module-level ``sentiment_adjustments`` dict.
        temperature: Sampling temperature (higher = more random output).
        max_length: Token budget for the generated continuation.

    Returns:
        The generated pickup line with the prompt text stripped off.
    """
    prompt = style_prompts[style] + sentiment_adjustments[sentiment]

    # GPT-2's tokenizer ships without a pad token; requesting padding=True
    # without assigning one raises a ValueError at tokenization time.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Set up the model inputs
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)

    # Generate text (inference only — no gradient tracking needed).
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            # Pass the attention mask explicitly: pad and eos share an id for
            # GPT-2, so generate() cannot infer the mask reliably on its own.
            attention_mask=inputs["attention_mask"],
            # max_new_tokens budgets only the continuation; max_length also
            # counted the prompt and could leave no room for any output.
            max_new_tokens=max_length,
            temperature=temperature,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )

    # Decode only the continuation tokens. The previous approach —
    # decode-everything then str.replace(prompt, "") — removed *every*
    # occurrence of the prompt text, not just the leading one.
    prompt_len = inputs["input_ids"].shape[1]
    pickup_line = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    return pickup_line
 
 
 
 
49
 
50
  def copy_to_clipboard(text):
51
  """Copy the generated pickup line to clipboard."""
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import random
5
+ import os
6
+
7
+ # Force CPU usage if you're having GPU issues
8
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
9
 
10
  # Load model and tokenizer
11
  model_name = "gpt2-medium" # You can change this to your fine-tuned model
12
  tokenizer = AutoTokenizer.from_pretrained(model_name)
13
  model = AutoModelForCausalLM.from_pretrained(model_name)
14
 
15
+ # Set up padding token
16
+ tokenizer.pad_token = tokenizer.eos_token
17
+ model.config.pad_token_id = model.config.eos_token_id
18
+
19
  # Sample prompts for different styles
20
  style_prompts = {
21
  "Classic": "Here's a romantic pickup line: ",
 
33
 
34
def generate_pickup_line(style, sentiment, temperature=0.7, max_length=50):
    """Generate a pickup line based on selected style and sentiment.

    Args:
        style: Key into the module-level ``style_prompts`` dict.
        sentiment: Key into the module-level ``sentiment_adjustments`` dict.
        temperature: Sampling temperature (higher = more random output).
        max_length: Token budget for the generated continuation.

    Returns:
        The generated line, a fallback message if generation came back
        empty, or an error message if anything raised.
    """
    try:
        prompt = style_prompts[style] + sentiment_adjustments[sentiment]

        # Set up the model inputs
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

        # Generate text (inference only — no gradient tracking needed).
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                # Pass the attention mask explicitly: pad and eos share an id
                # for GPT-2, so generate() cannot infer the mask reliably.
                attention_mask=inputs["attention_mask"],
                # max_new_tokens budgets only the continuation; max_length
                # also counted the prompt and could leave no room for output
                # (the likely cause of the empty-result fallback firing).
                max_new_tokens=max_length,
                temperature=temperature,
                num_return_sequences=1,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                no_repeat_ngram_size=2,
            )

        # Decode only the continuation tokens. str.replace(prompt, "") would
        # remove *every* occurrence of the prompt text, not just the prefix.
        prompt_len = inputs["input_ids"].shape[1]
        pickup_line = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        return pickup_line if pickup_line else "Please try again!"
    except Exception as e:
        # Broad catch is deliberate: this is the Gradio callback boundary and
        # must hand a user-visible string back to the UI rather than crash.
        return f"An error occurred. Please try again! Error: {str(e)}"
61
 
62
  def copy_to_clipboard(text):
63
  """Copy the generated pickup line to clipboard."""