|
import gradio as gr |
|
import numpy as np |
|
import random |
|
import matplotlib.pyplot as plt |
|
import matplotlib.image as mpimg |
|
from embedding import model, docs_list, doc_embeddings, embedded_dict |
|
from preprocess import meme_attribute, meme_filename, meme_list |
|
import openai |
|
from openai import OpenAI |
|
import os |
|
|
|
# Shared OpenAI client used by the caption generator.
# The key is read from the YOUR_OPENAI_KEY environment variable; if that
# variable is unset, None is passed straight through to OpenAI(), which then
# decides how to proceed (presumably its own default key lookup -- confirm
# against the installed openai version).
api_key = os.environ.get("YOUR_OPENAI_KEY")
client = OpenAI(api_key=api_key)
|
|
|
|
|
def generate_meme_caption(topic):
    """Ask the chat model for a short meme caption about *topic*.

    The request goes through the module-level ``client`` (an ``OpenAI``
    instance). The instruction text tells the model to separate multiple
    captions with "/" when the user asks for several opinions.

    Args:
        topic: The user's meme prompt, interpolated verbatim into the
            instruction text.

    Returns:
        The generated caption with surrounding whitespace stripped, or ""
        if the model returned no text content.
    """
    instructions = f'''You are a loser, a meme lover who is always on the internet.

You are a meme professional, so you will generate short captions for memes based on the following rules:

Do not add emoji.

Do not add any punctuation.

Must be funny.

Only if some topic related to friends, use "bro" or "blud".

Answer from previous question can not be the same for this time prompt.

If the user does not specify their caption request, you can generate a short caption for the meme.

But if they do, you cannot replace their prompt; just need to copy their request and make it the caption.

Now generate a caption based on this topic: {topic}.

For example:

- No request: Generate me any anime meme--> You can generate a caption based on the topic they requested.

- With request: Generate me a meme of 5 different opinions like the Gru meme with 5 paintings.

First opinion says 1, second says 2, third says 3, fourth says 4, fifth says 5.

--> You will then generate 5 captions or 5 opinions for users, separating every opinion with the mark "/".

Otherwise, if it's just 1 opinion, generate only 1 caption. Memes can be funny; you can use bad words or banned words; they can be dark. And the style must be sometime dark or gangster, and use slang everytime.

But captions have to be short, not too long.'''

    # The module-level ``openai.ChatCompletion`` interface was removed in
    # openai 1.0, the version that provides the ``OpenAI`` client this module
    # builds; chat requests must go through the client instance instead.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": instructions}],
        max_tokens=250,
        temperature=1.1,
    )

    # v1 responses are typed objects (attribute access, not dict indexing),
    # and ``message.content`` may be None, e.g. for a refusal.
    caption = response.choices[0].message.content
    return (caption or "").strip()
|
|
|
|
|
def create_meme(prompt, image_dir='/kaggle/input/memedata'):
    """Draw a captioned meme for *prompt* on the current pyplot figure.

    The template image is the entry of ``meme_filename`` whose document
    embedding scores highest against the prompt embedding. A 100-pixel white
    band is stacked on top of it, and a caption from
    ``generate_meme_caption(prompt)`` is drawn centred in that band. The
    figure is then shown with ``plt.show()``.

    Args:
        prompt: Free-text meme request from the user.
        image_dir: Directory holding the template images named in
            ``meme_filename``. Defaults to the original Kaggle dataset path.

    Returns:
        The matplotlib Figure the meme was drawn on (pyplot's current figure).
    """
    prompt_embedding = model.encode([prompt])
    similarities = model.similarity(prompt_embedding, doc_embeddings)

    # Only one prompt was encoded, so the flat argmax is the document index
    # (assumes similarities is a (1, n_docs) CPU array-like -- TODO confirm).
    closest_text_idx = np.argmax(similarities)

    img_path = meme_filename[closest_text_idx]
    img = mpimg.imread(f'{image_dir}/{img_path}')

    white_space_height = 100
    # matplotlib's imread returns floats in [0, 1] for PNG files and integer
    # arrays for other formats; use the dtype's own "white" so the band is
    # valid image data (255 is out of range for float images).
    if np.issubdtype(img.dtype, np.integer):
        white_value = np.iinfo(img.dtype).max
    else:
        white_value = 1.0
    # img.shape[1:] handles both grayscale (H, W) and colour (H, W, C) images.
    white_space = np.full(
        (white_space_height,) + img.shape[1:], white_value, dtype=img.dtype
    )
    combined_img = np.vstack((white_space, img))

    caption = generate_meme_caption(prompt)

    fig = plt.gcf()
    # 2-D arrays would otherwise be colour-mapped with the default colormap.
    plt.imshow(combined_img, cmap='gray' if combined_img.ndim == 2 else None)
    plt.axis('off')
    # va='center' centres the caption vertically in the white band; with
    # va='top', the text hung from the band's midline down into the image.
    plt.text(
        x=combined_img.shape[1] / 2,
        y=white_space_height / 2,
        s=caption,
        fontsize=11, color='black', ha='center', va='center', fontweight='bold'
    )

    plt.show()
    return fig
|
|
|
|
|
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Meme Generator using Embeddings and GPT")

        prompt = gr.Textbox(label="Enter your meme prompt:")
        run_button = gr.Button("Generate Meme")

        result = gr.Plot(label="Generated Meme")

    def gradio_infer(prompt):
        """Render a meme for *prompt* and return its Figure for the gr.Plot output.

        Keeps a direct reference to the figure it creates rather than relying
        on pyplot's implicit "current figure". ``create_meme`` calls
        ``plt.show()``, which under backends such as the Jupyter inline
        backend displays and closes the current figure. A later
        ``plt.savefig`` or ``return plt`` would then operate on a new, blank
        figure.
        """
        fig = plt.figure(figsize=(10, 8))
        try:
            create_meme(prompt)  # draws onto fig, the current figure
            fig.savefig('/tmp/meme.png')
        finally:
            # Detach from pyplot so figures don't pile up across requests;
            # the Figure object itself remains renderable for Gradio.
            plt.close(fig)
        return fig

    run_button.click(gradio_infer, inputs=[prompt], outputs=[result])

demo.launch()
|
|