"""Ethimar - AI marketing content generator.

A Gradio app that drafts platform-specific marketing copy with facebook/opt-350m,
pulls recent Google News headlines for extra context, and screens each draft with a
sentiment analyzer and a hate-speech classifier before showing it to the user.
"""

import gradio as gr
import torch  # PyTorch backend used by the transformers models
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import nltk
from datetime import datetime, timedelta
import requests
from bs4 import BeautifulSoup
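
# Third-party packages this script relies on (package names are assumptions for
# installation, not pinned anywhere in the code): gradio, torch, transformers, nltk,
# requests, beautifulsoup4, and lxml (BeautifulSoup's 'xml' parser used in
# fetch_recent_news needs it).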

# Make sure the NLTK punkt tokenizer data is available (downloaded on first run).
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')


def load_models():
    """Load the text generator plus the sentiment and hate-speech screening pipelines."""
    try:
        # Text generator used to draft the marketing copy.
        generator_model = "facebook/opt-350m"
        generator_tokenizer = AutoTokenizer.from_pretrained(generator_model)
        generator = AutoModelForCausalLM.from_pretrained(generator_model)

        # Sentiment screen: used to drop negative-sounding drafts.
        sentiment_analyzer = pipeline(
            "sentiment-analysis",
            model="finiteautomata/bertweet-base-sentiment-analysis"
        )

        # Safety screen: hate-speech classifier used to drop unsafe drafts.
        content_checker = pipeline(
            "text-classification",
            model="facebook/roberta-hate-speech-dynabench-r4-target"
        )

        return generator_tokenizer, generator, sentiment_analyzer, content_checker
    except Exception as e:
        print(f"Error loading models: {e}")
        raise

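# Optional sketch (an assumption, untested here, not something the code above does):
# inside load_models() the generator could be moved to a GPU when one is available, e.g.
#     if torch.cuda.is_available():
#         generator = generator.to("cuda")
# The code below keeps the CPU-only setup, so this is left as a comment.
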
def fetch_recent_news(query, num_articles=3):
    """Fetch recent Google News RSS items matching the query for use as prompt context."""
    base_url = "https://news.google.com/rss/search"
    params = {
        'q': query,
        'hl': 'en-US',
        'gl': 'US',
        'ceid': 'US:en'
    }

    try:
        response = requests.get(base_url, params=params, timeout=5)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'xml')
        items = soup.find_all('item', limit=num_articles)

        news_data = []
        for item in items:
            try:
                news_data.append({
                    'title': item.title.text,
                    'description': item.description.text if item.description else ""
                })
            except AttributeError:
                # Skip malformed items that are missing a title.
                continue

        return news_data
    except Exception as e:
        return [{'title': f'Using default context due to error: {e}', 'description': ''}]

def generate_content(
    product_name,
    product_description,
    target_audience,
    key_features,
    unique_benefits,
    platform,
    tone,
    generator_tokenizer,
    generator,
    sentiment_analyzer,
    content_checker
):
    char_limit = 280 if platform == "Twitter" else 500

    news_data = fetch_recent_news(f"{product_name} {target_audience}")
    news_context = "\n".join([f"Recent context: {item['title']}" for item in news_data])

    prompt = f"""
Create a {platform} post with these requirements:
- Product Name: {product_name}
- Description: {product_description}
- Target Audience: {target_audience}
- Key Features: {key_features}
- Unique Benefits: {unique_benefits}
- Tone: {tone}
- Maximum Length: {char_limit} characters

Recent Market Context:
{news_context}

Generate a compelling {platform} post that highlights the product's benefits while maintaining a {tone} tone.
"""

    try:
        inputs = generator_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        outputs = generator.generate(
            **inputs,
            # Budget of newly generated tokens; max_new_tokens counts tokens, not characters,
            # so this is only a rough proxy for char_limit (drafts are length-filtered below).
            max_new_tokens=char_limit // 2,
            num_return_sequences=3,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

        # Decode only the newly generated tokens so the prompt is not echoed back.
        prompt_length = inputs["input_ids"].shape[1]
        generated_texts = [
            generator_tokenizer.decode(output[prompt_length:], skip_special_tokens=True)
            for output in outputs
        ]

        filtered_content = []
        for text in generated_texts:
            text = text.strip()

            # Discard drafts that are too short or exceed the platform's character limit.
            if len(text) < 10 or len(text) > char_limit:
                continue

            # bertweet-base-sentiment-analysis labels drafts as POS / NEU / NEG.
            sentiment = sentiment_analyzer(text)[0]

            # roberta-hate-speech-dynabench-r4-target labels drafts as hate / nothate.
            safety_check = content_checker(text)[0]

            if sentiment['label'] != 'NEG' and safety_check['label'] == 'nothate':
                filtered_content.append({
                    'text': text,
                    'sentiment': sentiment['label'],
                    'safety_score': f"{float(safety_check['score']):.2f}"
                })

        return filtered_content
    except Exception as e:
        print(f"Error generating content: {e}")
        return []

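# Quick manual check (illustrative sketch, not wired into the app): load the models once
# and call generate_content() directly, bypassing the Gradio UI.
#
#   tok, gen, sentiment, checker = load_models()
#   drafts = generate_content(
#       "EcoBottle", "Sustainable water bottle made from recycled ocean plastic",
#       "Environmentally conscious young professionals",
#       "100% recycled materials, Insulated design, Leak-proof",
#       "Helps clean oceans, Keeps drinks cold for 24 hours",
#       "Twitter", "professional",
#       tok, gen, sentiment, checker,
#   )
#   print(drafts)
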
def create_interface():
    generator_tokenizer, generator, sentiment_analyzer, content_checker = load_models()

    def process_input(
        product_name,
        product_description,
        target_audience,
        key_features,
        unique_benefits,
        platform,
        tone
    ):
        try:
            results = generate_content(
                product_name,
                product_description,
                target_audience,
                key_features,
                unique_benefits,
                platform,
                tone,
                generator_tokenizer,
                generator,
                sentiment_analyzer,
                content_checker
            )

            if not results:
                return "No suitable content generated. Please try again with different parameters."

            output = ""
            for i, content in enumerate(results, 1):
                output += f"\nVersion {i}:\n"
                output += f"Content: {content['text']}\n"
                output += f"Sentiment: {content['sentiment']}\n"
                output += f"Safety Score: {content['safety_score']}\n"
                output += "-" * 50 + "\n"

            return output
        except Exception as e:
            return f"An error occurred: {e}"

    iface = gr.Interface(
        fn=process_input,
        inputs=[
            gr.Textbox(label="Product Name", placeholder="Enter product name"),
            gr.Textbox(label="Product Description", lines=3, placeholder="Brief description of your product"),
            gr.Textbox(label="Target Audience", placeholder="Who is this product for?"),
            gr.Textbox(label="Key Features", lines=2, placeholder="Main features of your product"),
            gr.Textbox(label="Unique Benefits", lines=2, placeholder="What makes your product special?"),
            gr.Radio(
                choices=["Twitter", "Instagram"],
                label="Platform",
                value="Twitter"
            ),
            gr.Textbox(label="Tone", placeholder="e.g., professional, casual, friendly"),
        ],
        outputs=gr.Textbox(label="Generated Content", lines=10),
        title="Ethimar - AI Marketing Content Generator",
        description="Generate ethical marketing content with AI-powered insights",
        theme="default",
        examples=[
            [
                "EcoBottle",
                "Sustainable water bottle made from recycled ocean plastic",
                "Environmentally conscious young professionals",
                "100% recycled materials, Insulated design, Leak-proof",
                "Helps clean oceans, Keeps drinks cold for 24 hours",
                "Twitter",
                "professional"
            ]
        ]
    )

    return iface

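# Note on serving (standard Gradio launch options, assumed rather than used below):
# iface.launch() also accepts server_name="0.0.0.0" to listen on all interfaces,
# or share=True for a temporary public link.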
if __name__ == "__main__":
    iface = create_interface()
    iface.launch()