|
import gradio as gr |
|
import numpy as np |
|
import random |
|
import matplotlib.pyplot as plt |
|
import matplotlib.image as mpimg |
|
from embedding import model, docs_list, doc_embeddings, embedded_dict |
|
import openai |
|
from openai import OpenAI |
|
import os |
|
|
|
# The API key is taken from the process environment rather than hard-coded;
# the lookup yields None when YOUR_OPENAI_KEY is not set.
api_key = os.environ.get("YOUR_OPENAI_KEY")

# Module-level OpenAI client constructed with that key.
client = OpenAI(api_key=api_key)
|
|
|
|
|
def generate_meme_caption(topic):
    """Ask the chat model for a short meme caption about ``topic``.

    A single user message (the instruction prompt below with ``topic``
    interpolated) is sent to ``gpt-4o-mini`` through the module-level
    ``client``, and the text of the first returned choice is handed back.

    Args:
        topic: Free-form text describing what the meme should be about.

    Returns:
        The message content of the first completion choice.
    """
    # Prompt text is kept flush-left so no source indentation leaks into the
    # string that is sent to the model.
    instructions = f'''You are a loser, a meme lover who is always on the internet.
You are a meme professional, so you will generate short captions for memes based on the following rules:
Do not add emoji.
Do not add any punctuation.
Must be funny.
Only if some topic related to friends, use "bro" or "blud".
Answer from previous question can not be the same for this time prompt.
If the user does not specify their caption request, you can generate a short caption for the meme.
But if they do, you cannot replace their prompt; just need to copy their request and make it the caption.
Now generate a caption based on this topic: {topic}.
'''

    # This file constructs an ``OpenAI`` client, which only exists in
    # openai>=1.0. In that release line the legacy ``openai.ChatCompletion``
    # entry point was removed, so the request must go through the client.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": instructions,
            }
        ],
        max_tokens=250,
        temperature=1.1,
    )

    # v1 responses are typed objects rather than dicts, hence attribute
    # access instead of ['choices'][0]['message']['content'].
    return response.choices[0].message.content
|
|
|
|
|
def _prepend_white_band(img, band_height):
    """Return ``img`` with a white strip of ``band_height`` rows stacked on top."""
    # imread returns grayscale images as 2-D arrays; expand them to three
    # identical channels so the band shape and imshow both see an RGB image
    # (a 2-D array would otherwise be colour-mapped instead of shown as gray).
    if img.ndim == 2:
        img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

    # imread yields floats in [0, 1] for PNGs and integer arrays for other
    # formats, so the value that means "white" depends on the dtype.
    white = 1.0 if np.issubdtype(img.dtype, np.floating) else 255
    band = np.full((band_height,) + img.shape[1:], white, dtype=img.dtype)
    return np.vstack((band, img))


def create_meme(prompt, image_dir='/kaggle/input/memedata'):
    """Draw the template image closest to ``prompt`` with a generated caption.

    The prompt is embedded with ``model``, compared against
    ``doc_embeddings``, and the entry with the highest similarity selects a
    file name from ``embedded_dict``. That image is loaded, a white band is
    added above it, and the caption from ``generate_meme_caption`` is drawn
    in the band on pyplot's current figure, which is then shown.

    Args:
        prompt: User text used both for the similarity search and as the
            caption topic.
        image_dir: Directory that holds the template images. Defaults to the
            Kaggle dataset path this script was written against.

    Returns:
        None. The result is rendered on the current matplotlib figure.
    """
    prompt_embedding = model.encode([prompt])
    similarities = model.similarity(prompt_embedding, doc_embeddings)

    # argmax without an axis gives the flat index of the best score. Cast it
    # to a plain int so it behaves as an ordinary key/index regardless of the
    # array or tensor scalar type the similarity call produced.
    closest_text_idx = int(np.argmax(similarities))

    filename = embedded_dict[closest_text_idx]["filename"]
    img = mpimg.imread(f'{image_dir}/{filename}')

    white_space_height = 100
    combined_img = _prepend_white_band(img, white_space_height)

    caption = generate_meme_caption(prompt)

    plt.imshow(combined_img)
    plt.axis('off')
    # Anchor the top of the text at the vertical middle of the white band,
    # centred horizontally over the image.
    plt.text(
        x=combined_img.shape[1] / 2,
        y=white_space_height / 2,
        s=caption,
        fontsize=11, color='black', ha='center', va='top', fontweight='bold'
    )

    plt.show()
|
|
|
|
|
def gradio_infer(prompt):
    """Gradio click handler: render a meme for ``prompt`` and return the figure.

    Args:
        prompt: Text entered in the prompt textbox.

    Returns:
        The matplotlib ``Figure`` that ``create_meme`` drew on.
    """
    # Keep an explicit handle on the figure instead of relying on pyplot's
    # "current figure", so what is saved and returned is the one drawn here.
    fig = plt.figure(figsize=(10, 8))
    try:
        create_meme(prompt)
        fig.savefig('/tmp/meme.png')
    finally:
        # Unregister the figure from pyplot; otherwise every click leaves one
        # more open figure behind. The Figure object itself stays usable.
        plt.close(fig)
    return fig


with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Meme Generator using Embeddings and GPT")

        prompt = gr.Textbox(label="Enter your meme prompt:")
        run_button = gr.Button("Generate Meme")

        result = gr.Plot(label="Generated Meme")

    # Event wiring has to happen while the Blocks context is still open.
    run_button.click(gradio_infer, inputs=[prompt], outputs=[result])

demo.launch()
|
|