Ozgur Unlu committed
Commit 69cb436 · 1 Parent(s): 383989e

more fixes

Files changed (1)
  1. app.py +43 -68
app.py CHANGED
@@ -28,8 +28,8 @@ def load_models():
     )
 
     try:
-        # Use a smaller text generation model
-        generator_model = "distilgpt2"  # Smaller than opt-350m
+        # Use GPT-2 instead of DistilGPT-2
+        generator_model = "gpt2"
         generator_tokenizer = AutoTokenizer.from_pretrained(generator_model)
         generator = AutoModelForCausalLM.from_pretrained(generator_model)
 
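As a side note on the model swap above, here is a minimal standalone sketch of loading the new checkpoint, assuming the same transformers AutoTokenizer and AutoModelForCausalLM imports that app.py already uses; the pad-token fallback is my addition (GPT-2 ships without a pad token) and mirrors the pad_token_id=generator_tokenizer.eos_token_id passed to generate() later in this diff.

# Sketch only: load the swapped-in "gpt2" checkpoint the same way app.py does
from transformers import AutoTokenizer, AutoModelForCausalLM

generator_model = "gpt2"
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model)
generator = AutoModelForCausalLM.from_pretrained(generator_model)

# GPT-2 has no pad token; reusing EOS keeps generate() well-defined when padding is needed
if generator_tokenizer.pad_token is None:
    generator_tokenizer.pad_token = generator_tokenizer.eos_token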
@@ -71,37 +71,28 @@ def generate_content(
 ):
     char_limit = 280 if platform == "Twitter" else 500
 
-    # Create a more structured prompt that works for any product type
-    prompt = f"""Task: Write a {tone} {platform} post that promotes a product.
+    # Simpler, more direct prompt
+    prompt = f"""Marketing post for {platform}:
 
-Guidelines:
-- Highlight key benefits
-- Include a clear call-to-action
-- Be concise and engaging
-- Match the {tone} tone
-- Maximum {char_limit} characters
-
-Product Details:
-Name: {product_name}
-Target Audience: {target_audience}
-Key Benefits: {unique_benefits}
-Main Features: {key_features}
-
-Write the {platform} post here (do not include labels or prefixes):"""
+{product_name} helps {target_audience} by {unique_benefits}. {product_description}. Features include {key_features}.
 
+Marketing post in {tone} tone:
+"""
+
     try:
-        # Generate multiple variations
-        inputs = generator_tokenizer(prompt, return_tensors="pt", max_length=256, truncation=True)
+        # Generate content with stricter parameters
+        inputs = generator_tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
         outputs = generator.generate(
             inputs["input_ids"],
-            max_length=char_limit + 100,  # Extra space for cleaning
-            num_return_sequences=3,
-            temperature=0.7,
-            top_p=0.9,
+            max_length=char_limit // 2,  # Keep it concise
+            num_return_sequences=2,
+            temperature=0.9,
+            top_p=0.85,
             do_sample=True,
             pad_token_id=generator_tokenizer.eos_token_id,
             no_repeat_ngram_size=2,
-            min_length=50,
+            min_length=30,
+            repetition_penalty=1.2
         )
 
         generated_texts = [generator_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
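Below is a rough, self-contained sketch of the updated decoding call; the prompt string, the char_limit value, and the model reload are placeholders so it runs on its own, not the app's real inputs.

# Placeholder prompt and limit, not real app inputs
from transformers import AutoTokenizer, AutoModelForCausalLM

generator_tokenizer = AutoTokenizer.from_pretrained("gpt2")
generator = AutoModelForCausalLM.from_pretrained("gpt2")

char_limit = 280  # the Twitter branch
prompt = "Marketing post for Twitter:\n\nExampleWidget helps busy teams by saving planning time.\n\nMarketing post in friendly tone:\n"

inputs = generator_tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
outputs = generator.generate(
    inputs["input_ids"],
    max_length=char_limit // 2,      # caps total tokens, prompt included
    num_return_sequences=2,          # two candidates instead of three
    temperature=0.9,                 # looser sampling than the old 0.7
    top_p=0.85,                      # tighter nucleus than the old 0.9
    do_sample=True,
    pad_token_id=generator_tokenizer.eos_token_id,
    no_repeat_ngram_size=2,
    min_length=30,
    repetition_penalty=1.2,          # new: discourages verbatim loops
)
generated_texts = [generator_tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
print(generated_texts[0])

Worth keeping in mind: max_length in transformers' generate() counts the prompt tokens as well, so with a prompt near the 128-token truncation cap the budget left by char_limit // 2 can be tiny; max_new_tokens would cap only the completion.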
@@ -109,66 +100,50 @@ Write the {platform} post here (do not include labels or prefixes):"""
         # Process and clean the generated content
         filtered_content = []
         for text in generated_texts:
-            # Extract only the post content (remove prompt)
-            post = text.split("Write the")[-1].split("post here")[-1].strip()
-            post = post.strip(':" ').strip()
+            # Extract only the part after "Marketing post in {tone} tone:"
+            try:
+                post = text.split("Marketing post in")[-1].split("tone:")[-1].strip()
+            except:
+                post = text
 
-            # Clean up common issues
-            post = post.replace("Product Details:", "")
-            post = post.replace("Guidelines:", "")
-            post = post.replace("[Post]:", "")
-            post = post.replace("Post:", "")
+            # Basic cleaning
+            post = ' '.join(post.split())  # Remove extra whitespace
 
-            # Skip if too short or contains prompt artifacts
-            if (len(post) < 30 or
-                "Task:" in post or
-                "Name:" in post or
-                "Target Audience:" in post):
+            # Skip if the post is too short or contains prompt artifacts
+            if len(post) < 20 or "Marketing post" in post or "tone:" in post:
                 continue
+
+            # Ensure it starts with product name if it's not already included
+            if product_name not in post:
+                post = f"{product_name}: {post}"
 
-            # Truncate to character limit
-            post = post[:char_limit]
-
-            # Add platform-specific formatting
-            if platform == "Instagram":
-                words = product_name.lower().split()
-                relevant_tags = " ".join([f"#{word}" for word in words if len(word) > 2])
-                if len(post) + len(relevant_tags) + 2 <= char_limit:
-                    post += f"\n{relevant_tags}"
+            # Ensure there's a call to action
+            if "learn more" not in post.lower() and len(post) + 15 <= char_limit:
+                post += " Learn more today!"
 
             # Check sentiment and safety
             try:
                 sentiment = sentiment_analyzer(post)[0]
                 safety_check = content_checker(post)[0]
 
-                # Only add if content seems appropriate
-                if sentiment['label'] != 'negative' and float(safety_check.get('score', 0)) > 0.7:
-                    filtered_content.append({
-                        'text': post,
-                        'sentiment': sentiment['label'],
-                        'safety_score': f"{float(safety_check.get('score', 0)):.2f}"
-                    })
+                filtered_content.append({
+                    'text': post[:char_limit],
+                    'sentiment': sentiment['label'],
+                    'safety_score': f"{float(safety_check.get('score', 0)):.2f}"
+                })
             except Exception as e:
                 print(f"Error in content analysis: {str(e)}")
                 continue
 
-        # If no valid content was generated, create a generic but customized post
+        # If no valid content was generated, use a structured fallback
         if not filtered_content:
-            # Create a generic but customized format that works for any product
-            features_list = key_features.split(',')[0].strip()
-            benefits_list = unique_benefits.split(',')[0].strip()
+            benefit = unique_benefits.split(',')[0].strip()
+            feature = key_features.split(',')[0].strip()
 
-            generic_post = (
-                f"Discover {product_name}! "
-                f"{product_description} "
-                f"Featuring {features_list}. "
-                f"{benefits_list}. "
-                f"Perfect for {target_audience}. "
-                f"Learn more today!"
-            )[:char_limit]
+            fallback_post = f"{product_name}: {product_description} {benefit}. Featuring {feature}. Perfect for {target_audience}. Learn more today!"
 
             filtered_content.append({
-                'text': generic_post,
+                'text': fallback_post[:char_limit],
                 'sentiment': 'positive',
                 'safety_score': '1.00'
             })
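For quick local testing, the new cleanup steps can be folded into a small helper like the sketch below; clean_generated_post and its defaults are illustrative names, not part of app.py, and the bare try/except from the diff is omitted since str.split cannot raise here.

def clean_generated_post(text, product_name, char_limit=280):
    """Illustrative refactor of the new cleanup steps from app.py."""
    # Keep only what follows the "Marketing post in ... tone:" cue
    post = text.split("Marketing post in")[-1].split("tone:")[-1].strip()
    # Collapse newlines and repeated spaces
    post = " ".join(post.split())
    # Reject outputs that are too short or still echo the prompt scaffolding
    if len(post) < 20 or "Marketing post" in post or "tone:" in post:
        return None
    # Make sure the product is named and a call to action is present
    if product_name not in post:
        post = f"{product_name}: {post}"
    if "learn more" not in post.lower() and len(post) + 15 <= char_limit:
        post += " Learn more today!"
    return post[:char_limit]


print(clean_generated_post(
    "Marketing post in friendly tone: A smarter way to plan your week.",
    "ExampleWidget",
))
# -> ExampleWidget: A smarter way to plan your week. Learn more today!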
@@ -178,7 +153,7 @@ Write the {platform} post here (do not include labels or prefixes):"""
     except Exception as e:
         print(f"Error in content generation: {str(e)}")
         return [{
-            'text': f"Discover {product_name}! {product_description[:100]}... Learn more today!",
+            'text': f"{product_name}: {product_description[:100]}... Learn more today!",
            'sentiment': 'neutral',
            'safety_score': '1.00'
        }]
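Finally, a hypothetical illustration of the reworded error-path fallback above; the product values are invented, and the [:100] slice is what keeps a long product description from dominating the fallback text.

# Invented inputs, just to show the shape of the error-path fallback
product_name = "ExampleWidget"
product_description = (
    "A pocket-sized planner that syncs tasks, notes, and reminders "
    "across every device so nothing slips through the cracks"
)

fallback = [{
    'text': f"{product_name}: {product_description[:100]}... Learn more today!",
    'sentiment': 'neutral',
    'safety_score': '1.00',
}]
print(fallback[0]['text'])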