Ozgur Unlu committed
Commit 69cb436 · 1 Parent(s): 383989e
more fixes
app.py
CHANGED
@@ -28,8 +28,8 @@ def load_models():
     )
 
     try:
-        # Use
-        generator_model = "distilgpt2"
+        # Use GPT-2 instead of DistilGPT-2
+        generator_model = "gpt2"
         generator_tokenizer = AutoTokenizer.from_pretrained(generator_model)
         generator = AutoModelForCausalLM.from_pretrained(generator_model)
 
@@ -71,37 +71,28 @@ def generate_content(
 ):
     char_limit = 280 if platform == "Twitter" else 500
 
-    #
-    prompt = f"""
+    # Simpler, more direct prompt
+    prompt = f"""Marketing post for {platform}:
 
-
-- Highlight key benefits
-- Include a clear call-to-action
-- Be concise and engaging
-- Match the {tone} tone
-- Maximum {char_limit} characters
-
-Product Details:
-Name: {product_name}
-Target Audience: {target_audience}
-Key Benefits: {unique_benefits}
-Main Features: {key_features}
-
-Write the {platform} post here (do not include labels or prefixes):"""
+{product_name} helps {target_audience} by {unique_benefits}. {product_description}. Features include {key_features}.
 
+Marketing post in {tone} tone:
+"""
+
     try:
-        # Generate
-        inputs = generator_tokenizer(prompt, return_tensors="pt", max_length=
+        # Generate content with stricter parameters
+        inputs = generator_tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
         outputs = generator.generate(
             inputs["input_ids"],
-            max_length=char_limit
-            num_return_sequences=
-            temperature=0.
-            top_p=0.
+            max_length=char_limit // 2,  # Keep it concise
+            num_return_sequences=2,
+            temperature=0.9,
+            top_p=0.85,
             do_sample=True,
             pad_token_id=generator_tokenizer.eos_token_id,
             no_repeat_ngram_size=2,
-            min_length=
+            min_length=30,
+            repetition_penalty=1.2
         )
 
         generated_texts = [generator_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
@@ -109,66 +100,50 @@ Write the {platform} post here (do not include labels or prefixes):"""
         # Process and clean the generated content
        filtered_content = []
         for text in generated_texts:
-            # Extract only the post
-
-
+            # Extract only the part after "Marketing post in {tone} tone:"
+            try:
+                post = text.split("Marketing post in")[-1].split("tone:")[-1].strip()
+            except:
+                post = text
 
-            #
-            post = post.
-            post = post.replace("Guidelines:", "")
-            post = post.replace("[Post]:", "")
-            post = post.replace("Post:", "")
+            # Basic cleaning
+            post = ' '.join(post.split())  # Remove extra whitespace
 
-            # Skip if too short or contains prompt artifacts
-            if
-                "Task:" in post or
-                "Name:" in post or
-                "Target Audience:" in post):
+            # Skip if the post is too short or contains prompt artifacts
+            if len(post) < 20 or "Marketing post" in post or "tone:" in post:
                 continue
+
+            # Ensure it starts with product name if it's not already included
+            if product_name not in post:
+                post = f"{product_name}: {post}"
 
-            #
-            post
-
-            # Add platform-specific formatting
-            if platform == "Instagram":
-                words = product_name.lower().split()
-                relevant_tags = " ".join([f"#{word}" for word in words if len(word) > 2])
-                if len(post) + len(relevant_tags) + 2 <= char_limit:
-                    post += f"\n{relevant_tags}"
+            # Ensure there's a call to action
+            if "learn more" not in post.lower() and len(post) + 15 <= char_limit:
+                post += " Learn more today!"
 
             # Check sentiment and safety
             try:
                 sentiment = sentiment_analyzer(post)[0]
                 safety_check = content_checker(post)[0]
 
-
-
-
-
-
-                    'safety_score': f"{float(safety_check.get('score', 0)):.2f}"
-                })
+                filtered_content.append({
+                    'text': post[:char_limit],
+                    'sentiment': sentiment['label'],
+                    'safety_score': f"{float(safety_check.get('score', 0)):.2f}"
+                })
             except Exception as e:
                 print(f"Error in content analysis: {str(e)}")
                 continue
 
-        # If no valid content was generated,
+        # If no valid content was generated, use a structured fallback
         if not filtered_content:
-
-
-            benefits_list = unique_benefits.split(',')[0].strip()
+            benefit = unique_benefits.split(',')[0].strip()
+            feature = key_features.split(',')[0].strip()
 
-
-                f"Discover {product_name}! "
-                f"{product_description} "
-                f"Featuring {features_list}. "
-                f"{benefits_list}. "
-                f"Perfect for {target_audience}. "
-                f"Learn more today!"
-            )[:char_limit]
+            fallback_post = f"{product_name}: {product_description} {benefit}. Featuring {feature}. Perfect for {target_audience}. Learn more today!"
 
             filtered_content.append({
-                'text':
+                'text': fallback_post[:char_limit],
                 'sentiment': 'positive',
                 'safety_score': '1.00'
             })
@@ -178,7 +153,7 @@ Write the {platform} post here (do not include labels or prefixes):"""
     except Exception as e:
         print(f"Error in content generation: {str(e)}")
         return [{
-            'text': f"
+            'text': f"{product_name}: {product_description[:100]}... Learn more today!",
            'sentiment': 'neutral',
            'safety_score': '1.00'
        }]
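For quick local testing, here is a minimal standalone sketch of the generation path this commit switches to: it loads gpt2 and calls generate() with the same sampling settings (temperature=0.9, top_p=0.85, repetition_penalty=1.2, min_length=30, no_repeat_ngram_size=2), then applies the same prompt-stripping cleanup. The product details, audience, and tone below are made-up placeholders, and the sentiment_analyzer / content_checker steps from app.py are deliberately left out.

# Minimal sketch, not the full app.py. Placeholder values are illustrative only.
from transformers import AutoModelForCausalLM, AutoTokenizer

generator_model = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(generator_model)
model = AutoModelForCausalLM.from_pretrained(generator_model)

platform = "Twitter"
char_limit = 280 if platform == "Twitter" else 500
product_name = "AcmeBottle"   # hypothetical product
tone = "friendly"             # hypothetical tone

prompt = f"""Marketing post for {platform}:

{product_name} helps busy commuters by keeping drinks cold for 24 hours. A vacuum-insulated steel bottle. Features include a leak-proof cap.

Marketing post in {tone} tone:
"""

inputs = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
outputs = model.generate(
    inputs["input_ids"],
    max_length=char_limit // 2,   # total length budget, kept deliberately short
    num_return_sequences=2,
    temperature=0.9,
    top_p=0.85,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=2,
    min_length=30,
    repetition_penalty=1.2,
)

for output in outputs:
    text = tokenizer.decode(output, skip_special_tokens=True)
    # Same cleanup idea as the commit: keep only what follows the prompt tail,
    # collapse whitespace, and prepend the product name if it is missing.
    post = text.split("Marketing post in")[-1].split("tone:")[-1].strip()
    post = " ".join(post.split())
    if product_name not in post:
        post = f"{product_name}: {post}"
    print(post[:char_limit])

Worth noting: max_length in generate() counts the prompt tokens as well, so with max_length=char_limit // 2 the budget for newly generated tokens shrinks as the prompt grows.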